Don't run model2 benchmark when model2_dir is not specified
This commit is contained in:
parent
af1b9c604b
commit
7ac7923345
|
@ -14,8 +14,8 @@ cmd:text("waifu2x-benchmark")
|
|||
cmd:text("Options:")
|
||||
|
||||
cmd:option("-dir", "./data/test", 'test image directory')
|
||||
cmd:option("-model1_dir", "./models/anime_style_art", 'model1 directory')
|
||||
cmd:option("-model2_dir", "./models/anime_style_art_rgb", 'model2 directory')
|
||||
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
|
||||
cmd:option("-model2_dir", "", 'model2 directory (optional)')
|
||||
cmd:option("-method", "scale", '(scale|noise)')
|
||||
cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
|
||||
cmd:option("-color", "rgb", '(rgb|y)')
|
||||
|
@ -83,32 +83,46 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|||
t = sys.clock()
|
||||
if input:size(3) == ground_truth:size(3) then
|
||||
model1_output = reconstruct.image(model1, input)
|
||||
model2_output = reconstruct.image(model2, input)
|
||||
if model2 then
|
||||
model2_output = reconstruct.image(model2, input)
|
||||
end
|
||||
else
|
||||
model1_output = reconstruct.scale(model1, 2.0, input)
|
||||
model2_output = reconstruct.scale(model2, 2.0, input)
|
||||
if model2 then
|
||||
model2_output = reconstruct.scale(model2, 2.0, input)
|
||||
end
|
||||
end
|
||||
if opt.color == "y" then
|
||||
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
|
||||
model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
|
||||
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
|
||||
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
|
||||
if model2 then
|
||||
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
|
||||
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
|
||||
end
|
||||
elseif opt.color == "rgb" then
|
||||
model1_mse = model1_mse + MSE(ground_truth, model1_output)
|
||||
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
|
||||
model2_mse = model2_mse + MSE(ground_truth, model2_output)
|
||||
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
|
||||
if model2 then
|
||||
model2_mse = model2_mse + MSE(ground_truth, model2_output)
|
||||
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
|
||||
end
|
||||
else
|
||||
error("Unknown color: " .. opt.color)
|
||||
end
|
||||
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
i, #x,
|
||||
model1_mse / i, model2_mse / i,
|
||||
model1_psnr / i, model2_psnr / i
|
||||
)
|
||||
)
|
||||
if model2 then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
i, #x,
|
||||
model1_mse / i, model2_mse / i,
|
||||
model1_psnr / i, model2_psnr / i
|
||||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
|
||||
i, #x,
|
||||
model1_mse / i, model1_psnr / i
|
||||
))
|
||||
end
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
|
@ -122,16 +136,34 @@ local function load_data(test_dir)
|
|||
end
|
||||
return test_x
|
||||
end
|
||||
|
||||
function load_model(filename)
|
||||
return torch.load(filename, "ascii")
|
||||
end
|
||||
print(opt)
|
||||
if opt.method == "scale" then
|
||||
local model1 = torch.load(path.join(opt.model1_dir, "scale2.0x_model.t7"), "ascii")
|
||||
local model2 = torch.load(path.join(opt.model2_dir, "scale2.0x_model.t7"), "ascii")
|
||||
local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
|
||||
local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
|
||||
local s1, model1 = pcall(load_model, f1)
|
||||
local s2, model2 = pcall(load_model, f2)
|
||||
if not s1 then
|
||||
error("Load error: " .. f1)
|
||||
end
|
||||
if not s2 then
|
||||
model2 = nil
|
||||
end
|
||||
local test_x = load_data(opt.dir)
|
||||
benchmark(opt, test_x, transform_scale, model1, model2)
|
||||
elseif opt.method == "noise" then
|
||||
local model1 = torch.load(path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
|
||||
local model2 = torch.load(path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii")
|
||||
local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
|
||||
local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level))
|
||||
local s1, model1 = pcall(load_model, f1)
|
||||
local s2, model2 = pcall(load_model, f2)
|
||||
if not s1 then
|
||||
error("Load error: " .. f1)
|
||||
end
|
||||
if not s2 then
|
||||
model2 = nil
|
||||
end
|
||||
local test_x = load_data(opt.dir)
|
||||
benchmark(opt, test_x, transform_jpeg, model1, model2)
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue