1
0
Fork 0
mirror of synced 2024-05-19 20:32:22 +12:00

Don't run model2 benchmark when model2_dir is not specified

This commit is contained in:
nagadomi 2015-11-09 06:23:57 +09:00
parent af1b9c604b
commit 7ac7923345

View file

@ -14,8 +14,8 @@ cmd:text("waifu2x-benchmark")
cmd:text("Options:")
cmd:option("-dir", "./data/test", 'test image directory')
cmd:option("-model1_dir", "./models/anime_style_art", 'model1 directory')
cmd:option("-model2_dir", "./models/anime_style_art_rgb", 'model2 directory')
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
cmd:option("-model2_dir", "", 'model2 directory (optional)')
cmd:option("-method", "scale", '(scale|noise)')
cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
cmd:option("-color", "rgb", '(rgb|y)')
@ -83,32 +83,46 @@ local function benchmark(opt, x, input_func, model1, model2)
t = sys.clock()
if input:size(3) == ground_truth:size(3) then
model1_output = reconstruct.image(model1, input)
model2_output = reconstruct.image(model2, input)
if model2 then
model2_output = reconstruct.image(model2, input)
end
else
model1_output = reconstruct.scale(model1, 2.0, input)
model2_output = reconstruct.scale(model2, 2.0, input)
if model2 then
model2_output = reconstruct.scale(model2, 2.0, input)
end
end
if opt.color == "y" then
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
if model2 then
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
end
elseif opt.color == "rgb" then
model1_mse = model1_mse + MSE(ground_truth, model1_output)
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
model2_mse = model2_mse + MSE(ground_truth, model2_output)
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
if model2 then
model2_mse = model2_mse + MSE(ground_truth, model2_output)
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
end
else
error("Unknown color: " .. opt.color)
end
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
i, #x,
model1_mse / i, model2_mse / i,
model1_psnr / i, model2_psnr / i
)
)
if model2 then
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
i, #x,
model1_mse / i, model2_mse / i,
model1_psnr / i, model2_psnr / i
))
else
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
i, #x,
model1_mse / i, model1_psnr / i
))
end
io.stdout:flush()
end
io.stdout:write("\n")
@ -122,16 +136,34 @@ local function load_data(test_dir)
end
return test_x
end
function load_model(filename)
return torch.load(filename, "ascii")
end
print(opt)
if opt.method == "scale" then
local model1 = torch.load(path.join(opt.model1_dir, "scale2.0x_model.t7"), "ascii")
local model2 = torch.load(path.join(opt.model2_dir, "scale2.0x_model.t7"), "ascii")
local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
local s1, model1 = pcall(load_model, f1)
local s2, model2 = pcall(load_model, f2)
if not s1 then
error("Load error: " .. f1)
end
if not s2 then
model2 = nil
end
local test_x = load_data(opt.dir)
benchmark(opt, test_x, transform_scale, model1, model2)
elseif opt.method == "noise" then
local model1 = torch.load(path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
local model2 = torch.load(path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level))
local s1, model1 = pcall(load_model, f1)
local s2, model2 = pcall(load_model, f2)
if not s1 then
error("Load error: " .. f1)
end
if not s2 then
model2 = nil
end
local test_x = load_data(opt.dir)
benchmark(opt, test_x, transform_jpeg, model1, model2)
end