diff --git a/tools/benchmark.lua b/tools/benchmark.lua index eb0a087..ffab09e 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -25,18 +25,29 @@ cmd:option("-jpeg_times", 1, 'jpeg compression times') cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times') cmd:option("-range_bug", 0, 'Reproducing the dynamic range bug that is caused by MATLAB\'s rgb2ycbcr(1|0)') cmd:option("-gamma_correction", 0, 'Resizing with colorspace correction(sRGB:gamma 2.2) (0|1)') +cmd:option("-save_image", 0, 'save converted images') +cmd:option("-save_baseline_image", 0, 'save baseline images') +cmd:option("-output_dir", "./", 'output directroy') +cmd:option("-show_progress", 1, 'show progressbar') +cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Catrom(Bicubic))') +local function to_bool(settings, name) + if settings[name] == 1 then + settings[name] = true + else + settings[name] = false + end +end local opt = cmd:parse(arg) torch.setdefaulttensortype('torch.FloatTensor') if cudnn then cudnn.fastest = true cudnn.benchmark = false end -if opt.gamma_correction == 1 then - opt.gamma_correction = true -else - opt.gamma_correction = false -end +to_bool(opt, "gamma_correction") +to_bool(opt, "save_image") +to_bool(opt, "save_baseline_image") +to_bool(opt, "show_progress") local function rgb2y_matlab(x) local y = torch.Tensor(1, x:size(2), x:size(3)):zero() @@ -115,7 +126,9 @@ local function benchmark(opt, x, input_func, model1, model2) local baseline_psnr = 0 for i = 1, #x do - local ground_truth = x[i] + local ground_truth = x[i].image + local basename = x[i].basename + local input, model1_output, model2_output, baseline_output input = input_func(ground_truth, opt) @@ -130,7 +143,7 @@ local function benchmark(opt, x, input_func, model1, model2) if model2 then model2_output = reconstruct.scale(model2, 2.0, input) end - baseline_output = baseline_scale(input, opt.filter) + baseline_output = baseline_scale(input, opt.baseline_filter) end model1_mse = model1_mse + MSE(ground_truth, model1_output, opt.color) model1_psnr = model1_psnr + PSNR(ground_truth, model1_output, opt.color) @@ -142,41 +155,57 @@ local function benchmark(opt, x, input_func, model1, model2) baseline_mse = baseline_mse + MSE(ground_truth, baseline_output, opt.color) baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output, opt.color) end - if model2 then - if baseline_output then - io.stdout:write( - string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r", - i, #x, - math.sqrt(baseline_mse / i), - math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), - baseline_psnr / i, - model1_psnr / i, model2_psnr / i - )) - else - io.stdout:write( - string.format("%d/%d; model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r", - i, #x, - math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), - model1_psnr / i, model2_psnr / i - )) + if opt.save_image then + if opt.save_baseline_image and baseline_output then + image.save(path.join(opt.output_dir, string.format("%s_baseline.png", basename)), + baseline_output) end - else - if baseline_output then - io.stdout:write( - string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r", - i, #x, - math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i), - baseline_psnr / i, model1_psnr / i - )) - else - io.stdout:write( - string.format("%d/%d; model1_rmse=%f, model1_psnr=%f \r", - i, #x, - math.sqrt(model1_mse / i), model1_psnr / i - )) + if model1_output then + image.save(path.join(opt.output_dir, string.format("%s_model1.png", basename)), + model1_output) + end + if model2_output then + image.save(path.join(opt.output_dir, string.format("%s_model2.png", basename)), + model2_output) end end - io.stdout:flush() + if opt.show_progress or i == #x then + if model2 then + if baseline_output then + io.stdout:write( + string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r", + i, #x, + math.sqrt(baseline_mse / i), + math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), + baseline_psnr / i, + model1_psnr / i, model2_psnr / i + )) + else + io.stdout:write( + string.format("%d/%d; model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r", + i, #x, + math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), + model1_psnr / i, model2_psnr / i + )) + end + else + if baseline_output then + io.stdout:write( + string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r", + i, #x, + math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i), + baseline_psnr / i, model1_psnr / i + )) + else + io.stdout:write( + string.format("%d/%d; model1_rmse=%f, model1_psnr=%f \r", + i, #x, + math.sqrt(model1_mse / i), model1_psnr / i + )) + end + end + io.stdout:flush() + end end io.stdout:write("\n") end @@ -184,15 +213,23 @@ local function load_data(test_dir) local test_x = {} local files = dir.getfiles(test_dir, "*.*") for i = 1, #files do - table.insert(test_x, iproc.crop_mod4(image_loader.load_float(files[i]))) - xlua.progress(i, #files) + local name = path.basename(files[i]) + local e = path.extension(name) + local base = name:sub(0, name:len() - e:len()) + table.insert(test_x, {image = iproc.crop_mod4(image_loader.load_float(files[i])), + basename = base}) + if opt.show_progress then + xlua.progress(i, #files) + end end return test_x end function load_model(filename) return torch.load(filename, "ascii") end -print(opt) +if opt.show_progress then + print(opt) +end if opt.method == "scale" then local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7") local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")