1
0
Fork 0
mirror of synced 2024-06-13 08:24:30 +12:00

Add -tta option

The TTA mode:
- 8x slower than normal mode
- improves PSNR +0.1
This commit is contained in:
nagadomi 2015-11-09 04:01:28 +09:00
parent 4322b63750
commit b335f3a9ad
2 changed files with 110 additions and 15 deletions

View file

@ -206,5 +206,64 @@ function reconstruct.scale(model, scale, x, block_size)
reconstruct.offset_size(model), block_size)
end
end
local function tta(f, model, x, block_size)
local average = nil
local offset = reconstruct.offset_size(model)
for i = 1, 4 do
local flip_f, iflip_f
if i == 1 then
flip_f = function (a) return a end
iflip_f = function (a) return a end
elseif i == 2 then
flip_f = image.vflip
iflip_f = image.vflip
elseif i == 3 then
flip_f = image.hflip
iflip_f = image.hflip
elseif i == 4 then
flip_f = function (a) return image.hflip(image.vflip(a)) end
iflip_f = function (a) return image.vflip(image.hflip(a)) end
end
for j = 1, 2 do
local tr_f, itr_f
if j == 1 then
tr_f = function (a) return a end
itr_f = function (a) return a end
elseif j == 2 then
tr_f = function(a) return a:transpose(2, 3):contiguous() end
itr_f = function(a) return a:transpose(2, 3):contiguous() end
end
local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
offset, block_size)))
if not average then
average = out
else
average:add(out)
end
end
end
return average:div(8.0)
end
function reconstruct.image_tta(model, x, block_size)
if reconstruct.is_rgb(model) then
return tta(reconstruct.image_rgb, model, x, block_size)
else
return tta(reconstruct.image_y, model, x, block_size)
end
end
function reconstruct.scale_tta(model, scale, x, block_size)
if reconstruct.is_rgb(model) then
local f = function (model, x, offset, block_size)
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
end
return tta(f, model, x, block_size)
else
local f = function (model, x, offset, block_size)
return reconstruct.scale_y(model, scale, x, offset, block_size)
end
return tta(f, model, x, block_size)
end
end
return reconstruct

View file

@ -13,6 +13,14 @@ local function convert_image(opt)
local x, alpha = image_loader.load_float(opt.i)
local new_x = nil
local t = sys.clock()
local scale_f, image_f
if opt.tta == 1 then
scale_f = reconstruct.scale_tta
image_f = reconstruct.image_tta
else
scale_f = reconstruct.scale
image_f = reconstruct.image
end
if opt.o == "(auto)" then
local name = path.basename(opt.i)
local e = path.extension(name)
@ -25,14 +33,14 @@ local function convert_image(opt)
if not model then
error("Load Error: " .. model_path)
end
new_x = reconstruct.image(model, x, opt.crop_size)
new_x = image_f(model, x, opt.crop_size)
elseif opt.m == "scale" then
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
local model = torch.load(model_path, "ascii")
if not model then
error("Load Error: " .. model_path)
end
new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
new_x = scale_f(model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" then
local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
local noise_model = torch.load(noise_model_path, "ascii")
@ -45,8 +53,8 @@ local function convert_image(opt)
if not scale_model then
error("Load Error: " .. scale_model_path)
end
x = reconstruct.image(noise_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
x = image_f(noise_model, x)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
@ -54,25 +62,52 @@ local function convert_image(opt)
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
end
local function convert_frames(opt)
local noise1_model, noise2_model, scale_model
local model_path, noise1_model, noise2_model, scale_model
local scale_f, image_f
if opt.tta == 1 then
scale_f = reconstruct.scale_tta
image_f = reconstruct.image_tta
else
scale_f = reconstruct.scale
image_f = reconstruct.image
end
if opt.m == "scale" then
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
scale_model = torch.load(model_path, "ascii")
if not scale_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise" and opt.noise_level == 1 then
local model_path = path.join(opt.model_dir, "noise1_model.t7")
model_path = path.join(opt.model_dir, "noise1_model.t7")
noise1_model = torch.load(model_path, "ascii")
if not noise1_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise" and opt.noise_level == 2 then
local model_path = path.join(opt.model_dir, "noise2_model.t7")
model_path = path.join(opt.model_dir, "noise2_model.t7")
noise2_model = torch.load(model_path, "ascii")
if not noise2_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise_scale" then
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
scale_model = torch.load(model_path, "ascii")
if not scale_model then
error("Load Error: " .. model_path)
end
if opt.noise_level == 1 then
model_path = path.join(opt.model_dir, "noise1_model.t7")
noise1_model = torch.load(model_path, "ascii")
if not noise1_model then
error("Load Error: " .. model_path)
end
elseif opt.noise_level == 2 then
model_path = path.join(opt.model_dir, "noise2_model.t7")
noise2_model = torch.load(model_path, "ascii")
if not noise2_model then
error("Load Error: " .. model_path)
end
end
end
local fp = io.open(opt.l)
if not fp then
@ -89,17 +124,17 @@ local function convert_frames(opt)
local x, alpha = image_loader.load_float(lines[i])
local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then
new_x = reconstruct.image(noise1_model, x, opt.crop_size)
new_x = image_f(noise1_model, x, opt.crop_size)
elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = reconstruct.image(noise2_model, x)
new_x = image_func(noise2_model, x)
elseif opt.m == "scale" then
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
x = reconstruct.image(noise1_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
x = image_f(noise1_model, x)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = reconstruct.image(noise2_model, x)
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
x = image_f(noise2_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
@ -139,6 +174,7 @@ local function waifu2x()
cmd:option("-crop_size", 128, 'patch size per process')
cmd:option("-resume", 0, "skip existing files (0|1)")
cmd:option("-thread", -1, "number of CPU threads")
cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
local opt = cmd:parse(arg)
if opt.thread > 0 then