1
0
Fork 0
mirror of synced 2024-05-19 04:12:19 +12:00

Add support for the test filelist in benchmark

This commit is contained in:
nagadomi 2016-08-21 06:54:23 +09:00
parent f72d756172
commit bfb67e61f4

View file

@ -15,6 +15,7 @@ cmd:text("waifu2x-benchmark")
cmd:text("Options:")
cmd:option("-dir", "./data/test", 'test image directory')
cmd:option("-file", "", 'test image file list')
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
cmd:option("-model2_dir", "", 'model2 directory (optional)')
cmd:option("-method", "scale", '(scale|noise|noise_scale|user)')
@ -43,6 +44,8 @@ cmd:option("-yuv420", 0, 'use yuv420 jpeg')
cmd:option("-name", "", 'model name for user method')
cmd:option("-x_dir", "", 'input image for user method')
cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be the same as x_dir')
cmd:option("-x_file", "", 'input image for user method')
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
local function to_bool(settings, name)
if settings[name] == 1 then
@ -431,7 +434,7 @@ local function benchmark(opt, x, model1, model2)
end
io.stdout:write("\n")
end
local function load_data(test_dir)
local function load_data_from_dir(test_dir)
local test_x = {}
local files = dir.getfiles(test_dir, "*.*")
for i = 1, #files do
@ -449,16 +452,47 @@ local function load_data(test_dir)
end
return test_x
end
local function load_data_from_file(test_file)
local test_x = {}
local files = utils.split(file.read(test_file), "\n")
for i = 1, #files do
local name = path.basename(files[i])
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
local img = image_loader.load_float(files[i])
if img then
table.insert(test_x, {y = iproc.crop_mod4(img),
basename = base})
end
if opt.show_progress then
xlua.progress(i, #files)
end
end
return test_x
end
local function get_basename(f)
local name = path.basename(f)
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
return base
end
local function load_user_data(y_dir, x_dir)
local function load_user_data(y_dir, y_file, x_dir, x_file)
local test = {}
local y_files = dir.getfiles(y_dir, "*.*")
local x_files = dir.getfiles(x_dir, "*.*")
local y_files
local x_files
if y_file:len() > 0 then
print(y_file)
y_files = utils.split(file.read(y_file), "\n")
else
y_files = dir.getfiles(y_dir, "*.*")
end
if x_file:len() > 0 then
x_files = utils.split(file.read(x_file), "\n")
else
x_files = dir.getfiles(x_dir, "*.*")
end
local basename_db = {}
for i = 1, #y_files do
basename_db[get_basename(y_files[i])] = {y = y_files[i]}
@ -535,7 +569,12 @@ if opt.method == "scale" then
if not s2 then
model2 = nil
end
local test_x = load_data(opt.dir)
local test_x
if opt.file:len() > 0 then
test_x = load_data_from_file(opt.file)
else
test_x = load_data_from_dir(opt.dir)
end
benchmark(opt, test_x, model1, model2)
elseif opt.method == "noise" then
local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
@ -548,7 +587,12 @@ elseif opt.method == "noise" then
if not s2 then
model2 = nil
end
local test_x = load_data(opt.dir)
local test_x
if opt.file:len() > 0 then
test_x = load_data_from_file(opt.file)
else
test_x = load_data_from_dir(opt.dir)
end
benchmark(opt, test_x, model1, model2)
elseif opt.method == "noise_scale" then
local model2 = nil
@ -556,7 +600,12 @@ elseif opt.method == "noise_scale" then
if opt.model2_dir:len() > 0 then
model2 = load_noise_scale_model(opt.model2_dir, opt.noise_level, opt.force_cudnn)
end
local test_x = load_data(opt.dir)
local test_x
if opt.file:len() > 0 then
test_x = load_data_from_file(opt.file)
else
test_x = load_data_from_dir(opt.dir)
end
benchmark(opt, test_x, model1, model2)
elseif opt.method == "user" then
local f1 = path.join(opt.model1_dir, string.format("%s_model.t7", opt.name))
@ -569,6 +618,6 @@ elseif opt.method == "user" then
if not s2 then
model2 = nil
end
local test = load_user_data(opt.y_dir, opt.x_dir)
local test = load_user_data(opt.y_dir, opt.y_file, opt.x_dir, opt.x_file)
benchmark(opt, test, model1, model2)
end