Add -tta option
The TTA mode: - 8x slower than normal mode - improves PSNR +0.1
This commit is contained in:
parent
4322b63750
commit
b335f3a9ad
|
@ -206,5 +206,64 @@ function reconstruct.scale(model, scale, x, block_size)
|
|||
reconstruct.offset_size(model), block_size)
|
||||
end
|
||||
end
|
||||
local function tta(f, model, x, block_size)
|
||||
local average = nil
|
||||
local offset = reconstruct.offset_size(model)
|
||||
for i = 1, 4 do
|
||||
local flip_f, iflip_f
|
||||
if i == 1 then
|
||||
flip_f = function (a) return a end
|
||||
iflip_f = function (a) return a end
|
||||
elseif i == 2 then
|
||||
flip_f = image.vflip
|
||||
iflip_f = image.vflip
|
||||
elseif i == 3 then
|
||||
flip_f = image.hflip
|
||||
iflip_f = image.hflip
|
||||
elseif i == 4 then
|
||||
flip_f = function (a) return image.hflip(image.vflip(a)) end
|
||||
iflip_f = function (a) return image.vflip(image.hflip(a)) end
|
||||
end
|
||||
for j = 1, 2 do
|
||||
local tr_f, itr_f
|
||||
if j == 1 then
|
||||
tr_f = function (a) return a end
|
||||
itr_f = function (a) return a end
|
||||
elseif j == 2 then
|
||||
tr_f = function(a) return a:transpose(2, 3):contiguous() end
|
||||
itr_f = function(a) return a:transpose(2, 3):contiguous() end
|
||||
end
|
||||
local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
|
||||
offset, block_size)))
|
||||
if not average then
|
||||
average = out
|
||||
else
|
||||
average:add(out)
|
||||
end
|
||||
end
|
||||
end
|
||||
return average:div(8.0)
|
||||
end
|
||||
function reconstruct.image_tta(model, x, block_size)
|
||||
if reconstruct.is_rgb(model) then
|
||||
return tta(reconstruct.image_rgb, model, x, block_size)
|
||||
else
|
||||
return tta(reconstruct.image_y, model, x, block_size)
|
||||
end
|
||||
end
|
||||
function reconstruct.scale_tta(model, scale, x, block_size)
|
||||
if reconstruct.is_rgb(model) then
|
||||
local f = function (model, x, offset, block_size)
|
||||
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
||||
end
|
||||
return tta(f, model, x, block_size)
|
||||
|
||||
else
|
||||
local f = function (model, x, offset, block_size)
|
||||
return reconstruct.scale_y(model, scale, x, offset, block_size)
|
||||
end
|
||||
return tta(f, model, x, block_size)
|
||||
end
|
||||
end
|
||||
|
||||
return reconstruct
|
||||
|
|
66
waifu2x.lua
66
waifu2x.lua
|
@ -13,6 +13,14 @@ local function convert_image(opt)
|
|||
local x, alpha = image_loader.load_float(opt.i)
|
||||
local new_x = nil
|
||||
local t = sys.clock()
|
||||
local scale_f, image_f
|
||||
if opt.tta == 1 then
|
||||
scale_f = reconstruct.scale_tta
|
||||
image_f = reconstruct.image_tta
|
||||
else
|
||||
scale_f = reconstruct.scale
|
||||
image_f = reconstruct.image
|
||||
end
|
||||
if opt.o == "(auto)" then
|
||||
local name = path.basename(opt.i)
|
||||
local e = path.extension(name)
|
||||
|
@ -25,14 +33,14 @@ local function convert_image(opt)
|
|||
if not model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
new_x = reconstruct.image(model, x, opt.crop_size)
|
||||
new_x = image_f(model, x, opt.crop_size)
|
||||
elseif opt.m == "scale" then
|
||||
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
local model = torch.load(model_path, "ascii")
|
||||
if not model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
new_x = reconstruct.scale(model, opt.scale, x, opt.crop_size)
|
||||
new_x = scale_f(model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" then
|
||||
local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
|
||||
local noise_model = torch.load(noise_model_path, "ascii")
|
||||
|
@ -45,8 +53,8 @@ local function convert_image(opt)
|
|||
if not scale_model then
|
||||
error("Load Error: " .. scale_model_path)
|
||||
end
|
||||
x = reconstruct.image(noise_model, x)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||
x = image_f(noise_model, x)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
|
@ -54,25 +62,52 @@ local function convert_image(opt)
|
|||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||
end
|
||||
local function convert_frames(opt)
|
||||
local noise1_model, noise2_model, scale_model
|
||||
local model_path, noise1_model, noise2_model, scale_model
|
||||
local scale_f, image_f
|
||||
if opt.tta == 1 then
|
||||
scale_f = reconstruct.scale_tta
|
||||
image_f = reconstruct.image_tta
|
||||
else
|
||||
scale_f = reconstruct.scale
|
||||
image_f = reconstruct.image
|
||||
end
|
||||
if opt.m == "scale" then
|
||||
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
scale_model = torch.load(model_path, "ascii")
|
||||
if not scale_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise" and opt.noise_level == 1 then
|
||||
local model_path = path.join(opt.model_dir, "noise1_model.t7")
|
||||
model_path = path.join(opt.model_dir, "noise1_model.t7")
|
||||
noise1_model = torch.load(model_path, "ascii")
|
||||
if not noise1_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||
local model_path = path.join(opt.model_dir, "noise2_model.t7")
|
||||
model_path = path.join(opt.model_dir, "noise2_model.t7")
|
||||
noise2_model = torch.load(model_path, "ascii")
|
||||
if not noise2_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise_scale" then
|
||||
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
scale_model = torch.load(model_path, "ascii")
|
||||
if not scale_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
if opt.noise_level == 1 then
|
||||
model_path = path.join(opt.model_dir, "noise1_model.t7")
|
||||
noise1_model = torch.load(model_path, "ascii")
|
||||
if not noise1_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.noise_level == 2 then
|
||||
model_path = path.join(opt.model_dir, "noise2_model.t7")
|
||||
noise2_model = torch.load(model_path, "ascii")
|
||||
if not noise2_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
end
|
||||
end
|
||||
local fp = io.open(opt.l)
|
||||
if not fp then
|
||||
|
@ -89,17 +124,17 @@ local function convert_frames(opt)
|
|||
local x, alpha = image_loader.load_float(lines[i])
|
||||
local new_x = nil
|
||||
if opt.m == "noise" and opt.noise_level == 1 then
|
||||
new_x = reconstruct.image(noise1_model, x, opt.crop_size)
|
||||
new_x = image_f(noise1_model, x, opt.crop_size)
|
||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||
new_x = reconstruct.image(noise2_model, x)
|
||||
new_x = image_func(noise2_model, x)
|
||||
elseif opt.m == "scale" then
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
|
||||
x = reconstruct.image(noise1_model, x)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||
x = image_f(noise1_model, x)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
|
||||
x = reconstruct.image(noise2_model, x)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, opt.crop_size)
|
||||
x = image_f(noise2_model, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
|
@ -139,6 +174,7 @@ local function waifu2x()
|
|||
cmd:option("-crop_size", 128, 'patch size per process')
|
||||
cmd:option("-resume", 0, "skip existing files (0|1)")
|
||||
cmd:option("-thread", -1, "number of CPU threads")
|
||||
cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if opt.thread > 0 then
|
||||
|
|
Loading…
Reference in a new issue