Add -tta and -resize_blur option to benchmark
This commit is contained in:
parent
abae4cb855
commit
5e222a3981
|
@ -19,6 +19,7 @@ cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
|
|||
cmd:option("-model2_dir", "", 'model2 directory (optional)')
|
||||
cmd:option("-method", "scale", '(scale|noise)')
|
||||
cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))")
|
||||
cmd:option("-resize_blur", 1.0, 'blur parameter for resize')
|
||||
cmd:option("-color", "y", '(rgb|y)')
|
||||
cmd:option("-noise_level", 1, 'model noise level')
|
||||
cmd:option("-jpeg_quality", 75, 'jpeg quality')
|
||||
|
@ -34,6 +35,7 @@ cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Ca
|
|||
cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt')
|
||||
cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
cmd:option("-tta", 0, 'tta')
|
||||
|
||||
local function to_bool(settings, name)
|
||||
if settings[name] == 1 then
|
||||
|
@ -50,6 +52,7 @@ if cudnn then
|
|||
end
|
||||
to_bool(opt, "gamma_correction")
|
||||
to_bool(opt, "save_all")
|
||||
to_bool(opt, "tta")
|
||||
if opt.save_all then
|
||||
opt.save_image = true
|
||||
opt.save_info = true
|
||||
|
@ -123,12 +126,12 @@ local function transform_scale(x, opt)
|
|||
return iproc.scale_with_gamma22(x,
|
||||
x:size(3) * 0.5,
|
||||
x:size(2) * 0.5,
|
||||
opt.filter)
|
||||
opt.filter, opt.resize_blur)
|
||||
else
|
||||
return iproc.scale(x,
|
||||
x:size(3) * 0.5,
|
||||
x:size(2) * 0.5,
|
||||
opt.filter)
|
||||
opt.filter, opt.resize_blur)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -139,6 +142,12 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|||
local model1_psnr = 0
|
||||
local model2_psnr = 0
|
||||
local baseline_psnr = 0
|
||||
local scale_f = reconstruct.scale
|
||||
local image_f = reconstruct.image
|
||||
if opt.tta then
|
||||
scale_f = reconstruct.scale_tta
|
||||
image_f = reconstruct.image_tta
|
||||
end
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i].image
|
||||
|
@ -149,14 +158,14 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|||
input = input_func(ground_truth, opt)
|
||||
t = sys.clock()
|
||||
if input:size(3) == ground_truth:size(3) then
|
||||
model1_output = reconstruct.image(model1, input)
|
||||
model1_output = image_f(model1, input)
|
||||
if model2 then
|
||||
model2_output = reconstruct.image(model2, input)
|
||||
model2_output = image_f(model2, input)
|
||||
end
|
||||
else
|
||||
model1_output = reconstruct.scale(model1, 2.0, input)
|
||||
model1_output = scale_f(model1, 2.0, input)
|
||||
if model2 then
|
||||
model2_output = reconstruct.scale(model2, 2.0, input)
|
||||
model2_output = scale_f(model2, 2.0, input)
|
||||
end
|
||||
baseline_output = baseline_scale(input, opt.baseline_filter)
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue