1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00

Reduce memory usage in benchmark

This commit is contained in:
nagadomi 2017-01-04 19:53:17 +09:00
parent 43a9b58fcb
commit 02cf265d48

View file

@ -227,12 +227,15 @@ local function benchmark(opt, x, model1, model2)
end
for i = 1, #x do
if i % 10 == 0 then
collectgarbage()
end
local basename = x[i].basename
local input, model1_output, model2_output, baseline_output, ground_truth
if opt.method == "scale" then
input = transform_scale(x[i].y, opt)
ground_truth = x[i].y
input = transform_scale(iproc.byte2float(x[i].y), opt)
ground_truth = iproc.byte2float(x[i].y)
if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
@ -250,8 +253,8 @@ local function benchmark(opt, x, model1, model2)
end
baseline_output = baseline_scale(input, opt.baseline_filter)
elseif opt.method == "scale4" then
input = transform_scale4(x[i].y, opt)
ground_truth = x[i].y
input = transform_scale4(iproc.byte2float(x[i].y), opt)
ground_truth = iproc.byte2float(x[i].y)
if opt.force_cudnn and i == 1 then -- run cuDNN benchmark first
model1_output = scale_f(model1, 2.0, input, opt.crop_size, opt.batch_size)
if model2 then
@ -270,8 +273,8 @@ local function benchmark(opt, x, model1, model2)
end
baseline_output = baseline_scale4(input, opt.baseline_filter)
elseif opt.method == "noise" then
input = transform_jpeg(x[i].y, opt)
ground_truth = x[i].y
input = transform_jpeg(iproc.byte2float(x[i].y), opt)
ground_truth = iproc.byte2float(x[i].y)
if opt.force_cudnn and i == 1 then
model1_output = image_f(model1, input, opt.crop_size, opt.batch_size)
@ -289,8 +292,8 @@ local function benchmark(opt, x, model1, model2)
end
baseline_output = input
elseif opt.method == "noise_scale" then
input = transform_scale_jpeg(x[i].y, opt)
ground_truth = x[i].y
input = transform_scale_jpeg(iproc.byte2float(x[i].y), opt)
ground_truth = iproc.byte2float(x[i].y)
if opt.force_cudnn and i == 1 then
if model1.noise_scale_model then
@ -355,8 +358,8 @@ local function benchmark(opt, x, model1, model2)
end
baseline_output = baseline_scale(input, opt.baseline_filter)
elseif opt.method == "user" then
input = x[i].x
ground_truth = x[i].y
input = iproc.byte2float(x[i].x)
ground_truth = iproc.byte2float(x[i].y)
local y_scale = ground_truth:size(2) / input:size(2)
if y_scale > 1 then
if opt.force_cudnn and i == 1 then
@ -390,8 +393,8 @@ local function benchmark(opt, x, model1, model2)
end
end
elseif opt.method == "diff" then
input = x[i].x
ground_truth = x[i].y
input = iproc.byte2float(x[i].x)
ground_truth = iproc.byte2float(x[i].y)
model1_output = input
end
if opt.border > 0 then
@ -521,7 +524,7 @@ local function load_data_from_dir(test_dir)
local name = path.basename(files[i])
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
local img = image_loader.load_float(files[i])
local img = image_loader.load_byte(files[i])
if img then
table.insert(test_x, {y = iproc.crop_mod4(img),
basename = base})
@ -529,6 +532,9 @@ local function load_data_from_dir(test_dir)
if opt.show_progress then
xlua.progress(i, #files)
end
if i % 10 == 0 then
collectgarbage()
end
end
return test_x
end
@ -539,7 +545,7 @@ local function load_data_from_file(test_file)
local name = path.basename(files[i])
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
local img = image_loader.load_float(files[i])
local img = image_loader.load_byte(files[i])
if img then
table.insert(test_x, {y = iproc.crop_mod4(img),
basename = base})
@ -547,6 +553,9 @@ local function load_data_from_file(test_file)
if opt.show_progress then
xlua.progress(i, #files)
end
if i % 10 == 0 then
collectgarbage()
end
end
return test_x
end
@ -592,8 +601,8 @@ local function load_user_data(y_dir, y_file, x_dir, x_file)
end
for i = 1, #y_files do
local key = get_basename(y_files[i])
local x = image_loader.load_float(basename_db[key].x)
local y = image_loader.load_float(basename_db[key].y)
local x = image_loader.load_byte(basename_db[key].x)
local y = image_loader.load_byte(basename_db[key].y)
if x and y then
table.insert(test, {y = y,
x = x,
@ -602,6 +611,9 @@ local function load_user_data(y_dir, y_file, x_dir, x_file)
if opt.show_progress then
xlua.progress(i, #y_files)
end
if i % 10 == 0 then
collectgarbage()
end
end
return test
end