From d31fbe9bb17b871a586212b188c22fc5c323dc51 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Tue, 12 Apr 2016 02:26:30 +0900 Subject: [PATCH] Improve benchmark script --- tools/benchmark.lua | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/tools/benchmark.lua b/tools/benchmark.lua index ffab09e..93b4cf1 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -7,6 +7,7 @@ local iproc = require 'iproc' local reconstruct = require 'reconstruct' local image_loader = require 'image_loader' local gm = require 'graphicsmagick' +local cjson = require 'cjson' local cmd = torch.CmdLine() cmd:text() @@ -30,6 +31,8 @@ cmd:option("-save_baseline_image", 0, 'save baseline images') cmd:option("-output_dir", "./", 'output directroy') cmd:option("-show_progress", 1, 'show progressbar') cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Catrom(Bicubic))') +cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt') +cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option') local function to_bool(settings, name) if settings[name] == 1 then @@ -45,8 +48,16 @@ if cudnn then cudnn.benchmark = false end to_bool(opt, "gamma_correction") -to_bool(opt, "save_image") -to_bool(opt, "save_baseline_image") +to_bool(opt, "save_all") +if opt.save_all then + opt.save_image = true + opt.save_info = true + opt.save_baseline_image = true +else + to_bool(opt, "save_image") + to_bool(opt, "save_info") + to_bool(opt, "save_baseline_image") +end to_bool(opt, "show_progress") local function rgb2y_matlab(x) @@ -207,6 +218,23 @@ local function benchmark(opt, x, input_func, model1, model2) io.stdout:flush() end end + if opt.save_info then + local fp = io.open(path.join(opt.output_dir, "benchmark.txt"), "w") + fp:write("options : " .. cjson.encode(opt) .. "\n") + if baseline_psnr > 0 then + fp:write(string.format("baseline: RMSE = %.3f, PSNR = %.3f\n", + math.sqrt(baseline_mse / #x), baseline_psnr / #x)) + end + if model1_psnr > 0 then + fp:write(string.format("model1 : RMSE = %.3f, PSNR = %.3f\n", + math.sqrt(model1_mse / #x), model1_psnr / #x)) + end + if model2_psnr > 0 then + fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f\n", + math.sqrt(model2_mse / #x), model2_psnr / #x)) + end + fp:close() + end io.stdout:write("\n") end local function load_data(test_dir) @@ -216,8 +244,11 @@ local function load_data(test_dir) local name = path.basename(files[i]) local e = path.extension(name) local base = name:sub(0, name:len() - e:len()) - table.insert(test_x, {image = iproc.crop_mod4(image_loader.load_float(files[i])), - basename = base}) + local img = image_loader.load_float(files[i]) + if img then + table.insert(test_x, {image = iproc.crop_mod4(img), + basename = base}) + end if opt.show_progress then xlua.progress(i, #files) end