Improve benchmark script
This commit is contained in:
parent
1407973b88
commit
d31fbe9bb1
|
@ -7,6 +7,7 @@ local iproc = require 'iproc'
|
||||||
local reconstruct = require 'reconstruct'
|
local reconstruct = require 'reconstruct'
|
||||||
local image_loader = require 'image_loader'
|
local image_loader = require 'image_loader'
|
||||||
local gm = require 'graphicsmagick'
|
local gm = require 'graphicsmagick'
|
||||||
|
local cjson = require 'cjson'
|
||||||
|
|
||||||
local cmd = torch.CmdLine()
|
local cmd = torch.CmdLine()
|
||||||
cmd:text()
|
cmd:text()
|
||||||
|
@ -30,6 +31,8 @@ cmd:option("-save_baseline_image", 0, 'save baseline images')
|
||||||
cmd:option("-output_dir", "./", 'output directroy')
|
cmd:option("-output_dir", "./", 'output directroy')
|
||||||
cmd:option("-show_progress", 1, 'show progressbar')
|
cmd:option("-show_progress", 1, 'show progressbar')
|
||||||
cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Catrom(Bicubic))')
|
cmd:option("-baseline_filter", "Catrom", 'baseline interpolation (Box|Lanczos|Catrom(Bicubic))')
|
||||||
|
cmd:option("-save_info", 0, 'save score and parameters to benchmark.txt')
|
||||||
|
cmd:option("-save_all", 0, 'group -save_info, -save_image and -save_baseline_image option')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
if settings[name] == 1 then
|
if settings[name] == 1 then
|
||||||
|
@ -45,8 +48,16 @@ if cudnn then
|
||||||
cudnn.benchmark = false
|
cudnn.benchmark = false
|
||||||
end
|
end
|
||||||
to_bool(opt, "gamma_correction")
|
to_bool(opt, "gamma_correction")
|
||||||
to_bool(opt, "save_image")
|
to_bool(opt, "save_all")
|
||||||
to_bool(opt, "save_baseline_image")
|
if opt.save_all then
|
||||||
|
opt.save_image = true
|
||||||
|
opt.save_info = true
|
||||||
|
opt.save_baseline_image = true
|
||||||
|
else
|
||||||
|
to_bool(opt, "save_image")
|
||||||
|
to_bool(opt, "save_info")
|
||||||
|
to_bool(opt, "save_baseline_image")
|
||||||
|
end
|
||||||
to_bool(opt, "show_progress")
|
to_bool(opt, "show_progress")
|
||||||
|
|
||||||
local function rgb2y_matlab(x)
|
local function rgb2y_matlab(x)
|
||||||
|
@ -207,6 +218,23 @@ local function benchmark(opt, x, input_func, model1, model2)
|
||||||
io.stdout:flush()
|
io.stdout:flush()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
if opt.save_info then
|
||||||
|
local fp = io.open(path.join(opt.output_dir, "benchmark.txt"), "w")
|
||||||
|
fp:write("options : " .. cjson.encode(opt) .. "\n")
|
||||||
|
if baseline_psnr > 0 then
|
||||||
|
fp:write(string.format("baseline: RMSE = %.3f, PSNR = %.3f\n",
|
||||||
|
math.sqrt(baseline_mse / #x), baseline_psnr / #x))
|
||||||
|
end
|
||||||
|
if model1_psnr > 0 then
|
||||||
|
fp:write(string.format("model1 : RMSE = %.3f, PSNR = %.3f\n",
|
||||||
|
math.sqrt(model1_mse / #x), model1_psnr / #x))
|
||||||
|
end
|
||||||
|
if model2_psnr > 0 then
|
||||||
|
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f\n",
|
||||||
|
math.sqrt(model2_mse / #x), model2_psnr / #x))
|
||||||
|
end
|
||||||
|
fp:close()
|
||||||
|
end
|
||||||
io.stdout:write("\n")
|
io.stdout:write("\n")
|
||||||
end
|
end
|
||||||
local function load_data(test_dir)
|
local function load_data(test_dir)
|
||||||
|
@ -216,8 +244,11 @@ local function load_data(test_dir)
|
||||||
local name = path.basename(files[i])
|
local name = path.basename(files[i])
|
||||||
local e = path.extension(name)
|
local e = path.extension(name)
|
||||||
local base = name:sub(0, name:len() - e:len())
|
local base = name:sub(0, name:len() - e:len())
|
||||||
table.insert(test_x, {image = iproc.crop_mod4(image_loader.load_float(files[i])),
|
local img = image_loader.load_float(files[i])
|
||||||
basename = base})
|
if img then
|
||||||
|
table.insert(test_x, {image = iproc.crop_mod4(img),
|
||||||
|
basename = base})
|
||||||
|
end
|
||||||
if opt.show_progress then
|
if opt.show_progress then
|
||||||
xlua.progress(i, #files)
|
xlua.progress(i, #files)
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue