diff --git a/tools/benchmark.lua b/tools/benchmark.lua index 4cb50f4..233e542 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -34,7 +34,10 @@ cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Ca cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt') cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option') cmd:option("-thread", -1, 'number of CPU threads') -cmd:option("-tta", 0, 'tta') +cmd:option("-tta", 0, 'use tta') +cmd:option("-tta_level", 8, 'tta level') +cmd:option("-crop_size", 128, 'patch size per process') +cmd:option("-batch_size", 1, 'batch_size') local function to_bool(settings, name) if settings[name] == 1 then @@ -136,8 +139,14 @@ local function benchmark(opt, x, input_func, model1, model2) local scale_f = reconstruct.scale local image_f = reconstruct.image if opt.tta then - scale_f = reconstruct.scale_tta - image_f = reconstruct.image_tta + scale_f = function(model, scale, x, block_size, batch_size) + return reconstruct.scale_tta(model, opt.tta_level, + scale, x, block_size, batch_size) + end + image_f = function(model, x, block_size, batch_size) + return reconstruct.image_tta(model, opt.tta_level, + x, block_size, batch_size) + end end for i = 1, #x do @@ -149,14 +158,14 @@ local function benchmark(opt, x, input_func, model1, model2) input = input_func(ground_truth, opt) t = sys.clock() if input:size(3) == ground_truth:size(3) then - model1_output = image_f(model1, input) + model1_output = image_f(model1, input, opt.crop_size, opt.batch_size) if model2 then - model2_output = image_f(model2, input) + model2_output = image_f(model2, input, opt.crop_size, opt.batch_size) end else - model1_output = scale_f(model1, 2.0, input) + model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size) if model2 then - model2_output = scale_f(model2, 2.0, input) + model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size) end baseline_output = baseline_scale(input, opt.baseline_filter) end