1
0
Fork 0
mirror of synced 2024-05-19 04:12:19 +12:00

Add -crop_size and -batch_size option to tools/benchmark.lua. Fix a bug in tta mode.

This commit is contained in:
nagadomi 2016-06-10 09:20:43 +09:00
parent b8ff8c6787
commit c16d0a07a2

View file

@ -34,7 +34,10 @@ cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Ca
cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt')
cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option')
cmd:option("-thread", -1, 'number of CPU threads')
cmd:option("-tta", 0, 'tta')
cmd:option("-tta", 0, 'use tta')
cmd:option("-tta_level", 8, 'tta level')
cmd:option("-crop_size", 128, 'patch size per process')
cmd:option("-batch_size", 1, 'batch_size')
local function to_bool(settings, name)
if settings[name] == 1 then
@ -136,8 +139,14 @@ local function benchmark(opt, x, input_func, model1, model2)
local scale_f = reconstruct.scale
local image_f = reconstruct.image
if opt.tta then
scale_f = reconstruct.scale_tta
image_f = reconstruct.image_tta
scale_f = function(model, scale, x, block_size, batch_size)
return reconstruct.scale_tta(model, opt.tta_level,
scale, x, block_size, batch_size)
end
image_f = function(model, x, block_size, batch_size)
return reconstruct.image_tta(model, opt.tta_level,
x, block_size, batch_size)
end
end
for i = 1, #x do
@ -149,14 +158,14 @@ local function benchmark(opt, x, input_func, model1, model2)
input = input_func(ground_truth, opt)
t = sys.clock()
if input:size(3) == ground_truth:size(3) then
model1_output = image_f(model1, input)
model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
if model2 then
model2_output = image_f(model2, input)
model2_output = image_f(model2, input, opt.crop_size, opt.batch_size)
end
else
model1_output = scale_f(model1, 2.0, input)
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
if model2 then
model2_output = scale_f(model2, 2.0, input)
model2_output = scale_f(model2, 2.0, input, opt.crop_size, opt.batch_size)
end
baseline_output = baseline_scale(input, opt.baseline_filter)
end