diff --git a/tools/benchmark.lua b/tools/benchmark.lua index 7967034..aa07197 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -14,8 +14,8 @@ cmd:text("waifu2x-benchmark") cmd:text("Options:") cmd:option("-dir", "./data/test", 'test image directory') -cmd:option("-model1_dir", "./models/anime_style_art", 'model1 directory') -cmd:option("-model2_dir", "./models/anime_style_art_rgb", 'model2 directory') +cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory') +cmd:option("-model2_dir", "", 'model2 directory (optional)') cmd:option("-method", "scale", '(scale|noise)') cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)") cmd:option("-color", "rgb", '(rgb|y)') @@ -83,32 +83,46 @@ local function benchmark(opt, x, input_func, model1, model2) t = sys.clock() if input:size(3) == ground_truth:size(3) then model1_output = reconstruct.image(model1, input) - model2_output = reconstruct.image(model2, input) + if model2 then + model2_output = reconstruct.image(model2, input) + end else model1_output = reconstruct.scale(model1, 2.0, input) - model2_output = reconstruct.scale(model2, 2.0, input) + if model2 then + model2_output = reconstruct.scale(model2, 2.0, input) + end end if opt.color == "y" then model1_mse = model1_mse + YMSE(ground_truth, model1_output) model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output) - model2_mse = model2_mse + YMSE(ground_truth, model2_output) - model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output) + if model2 then + model2_mse = model2_mse + YMSE(ground_truth, model2_output) + model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output) + end elseif opt.color == "rgb" then model1_mse = model1_mse + MSE(ground_truth, model1_output) model1_psnr = model1_psnr + PSNR(ground_truth, model1_output) - model2_mse = model2_mse + MSE(ground_truth, model2_output) - model2_psnr = model2_psnr + PSNR(ground_truth, model2_output) + if model2 then + model2_mse = model2_mse + MSE(ground_truth, model2_output) + model2_psnr = model2_psnr + PSNR(ground_truth, model2_output) + end else error("Unknown color: " .. opt.color) end - - io.stdout:write( - string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r", - i, #x, - model1_mse / i, model2_mse / i, - model1_psnr / i, model2_psnr / i - ) - ) + if model2 then + io.stdout:write( + string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r", + i, #x, + model1_mse / i, model2_mse / i, + model1_psnr / i, model2_psnr / i + )) + else + io.stdout:write( + string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r", + i, #x, + model1_mse / i, model1_psnr / i + )) + end io.stdout:flush() end io.stdout:write("\n") @@ -122,16 +136,34 @@ local function load_data(test_dir) end return test_x end - +function load_model(filename) + return torch.load(filename, "ascii") +end print(opt) if opt.method == "scale" then - local model1 = torch.load(path.join(opt.model1_dir, "scale2.0x_model.t7"), "ascii") - local model2 = torch.load(path.join(opt.model2_dir, "scale2.0x_model.t7"), "ascii") + local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7") + local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7") + local s1, model1 = pcall(load_model, f1) + local s2, model2 = pcall(load_model, f2) + if not s1 then + error("Load error: " .. f1) + end + if not s2 then + model2 = nil + end local test_x = load_data(opt.dir) benchmark(opt, test_x, transform_scale, model1, model2) elseif opt.method == "noise" then - local model1 = torch.load(path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii") - local model2 = torch.load(path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)), "ascii") + local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level)) + local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level)) + local s1, model1 = pcall(load_model, f1) + local s2, model2 = pcall(load_model, f2) + if not s1 then + error("Load error: " .. f1) + end + if not s2 then + model2 = nil + end local test_x = load_data(opt.dir) benchmark(opt, test_x, transform_jpeg, model1, model2) end