From 5e222a39815734446696bdf6b4b055a0a15bf052 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Thu, 2 Jun 2016 10:15:54 +0900 Subject: [PATCH] Add -tta and -resize_blur option to benchmark --- tools/benchmark.lua | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tools/benchmark.lua b/tools/benchmark.lua index 6c5de4c..1375a61 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -19,6 +19,7 @@ cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory') cmd:option("-model2_dir", "", 'model2 directory (optional)') cmd:option("-method", "scale", '(scale|noise)') cmd:option("-filter", "Catrom", "downscaling filter (Box|Lanczos|Catrom(Bicubic))") +cmd:option("-resize_blur", 1.0, 'blur parameter for resize') cmd:option("-color", "y", '(rgb|y)') cmd:option("-noise_level", 1, 'model noise level') cmd:option("-jpeg_quality", 75, 'jpeg quality') @@ -34,6 +35,7 @@ cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Ca cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt') cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option') cmd:option("-thread", -1, 'number of CPU threads') +cmd:option("-tta", 0, 'tta') local function to_bool(settings, name) if settings[name] == 1 then @@ -50,6 +52,7 @@ if cudnn then end to_bool(opt, "gamma_correction") to_bool(opt, "save_all") +to_bool(opt, "tta") if opt.save_all then opt.save_image = true opt.save_info = true @@ -123,12 +126,12 @@ local function transform_scale(x, opt) return iproc.scale_with_gamma22(x, x:size(3) * 0.5, x:size(2) * 0.5, - opt.filter) + opt.filter, opt.resize_blur) else return iproc.scale(x, x:size(3) * 0.5, x:size(2) * 0.5, - opt.filter) + opt.filter, opt.resize_blur) end end @@ -139,6 +142,12 @@ local function benchmark(opt, x, input_func, model1, model2) local model1_psnr = 0 local model2_psnr = 0 local baseline_psnr = 0 + local scale_f = reconstruct.scale + local image_f = reconstruct.image + if opt.tta then + scale_f = reconstruct.scale_tta + image_f = reconstruct.image_tta + end for i = 1, #x do local ground_truth = x[i].image @@ -149,14 +158,14 @@ local function benchmark(opt, x, input_func, model1, model2) input = input_func(ground_truth, opt) t = sys.clock() if input:size(3) == ground_truth:size(3) then - model1_output = reconstruct.image(model1, input) + model1_output = image_f(model1, input) if model2 then - model2_output = reconstruct.image(model2, input) + model2_output = image_f(model2, input) end else - model1_output = reconstruct.scale(model1, 2.0, input) + model1_output = scale_f(model1, 2.0, input) if model2 then - model2_output = reconstruct.scale(model2, 2.0, input) + model2_output = scale_f(model2, 2.0, input) end baseline_output = baseline_scale(input, opt.baseline_filter) end