1
0
Fork 0
mirror of synced 2024-05-16 19:02:21 +12:00

add support for jaccard in benchmark

This commit is contained in:
nagadomi 2017-02-21 21:39:56 +09:00
parent 385020e0e1
commit d8b7df4505

View file

@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
cmd:option("-x_file", "", 'input image for user method')
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
cmd:option("-border", 0, 'border px that will removed')
cmd:option("-metric", "", '(jaccard)')
local function to_bool(settings, name)
if settings[name] == 1 then
@ -198,8 +199,34 @@ local function remove_border(x, border)
x:size(3) - border,
x:size(2) - border)
end
local function create_metric(metric)
if metric and metric:len() > 0 then
if metric == "jaccard" then
return {
name = "jaccard",
func = function (a, b)
local ga = iproc.rgb2y(a)
local gb = iproc.rgb2y(b)
local ba = torch.Tensor():resizeAs(ga)
local bb = torch.Tensor():resizeAs(gb)
ba:zero()
bb:zero()
ba[torch.gt(ga, 0.5)] = 1.0
bb[torch.gt(gb, 0.5)] = 1.0
local num_a = ba:sum()
local num_b = bb:sum()
local a_and_b = ba:cmul(bb):sum()
return (a_and_b / (num_a + num_b - a_and_b))
end}
else
error("unknown metric: " .. metric)
end
else
return nil
end
end
local function benchmark(opt, x, model1, model2)
local mse1, mse2
local mse1, mse2, am1, am2
local won = {0, 0}
local model1_mse = 0
local model2_mse = 0
@ -212,6 +239,13 @@ local function benchmark(opt, x, model1, model2)
local scale_f = reconstruct.scale
local image_f = reconstruct.image
local detail_fp = nil
local am = nil
local model1_am = 0
local model2_am = 0
if opt.method == "user" or opt.method == "diff" then
am = create_metric(opt.metric)
end
if opt.save_info then
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
end
@ -401,32 +435,57 @@ local function benchmark(opt, x, model1, model2)
ground_truth = remove_border(ground_truth, opt.border)
model1_output = remove_border(model1_output, opt.border)
end
mse1 = MSE(ground_truth, model1_output, opt.color)
model1_mse = model1_mse + mse1
model1_psnr = model1_psnr + MSE2PSNR(mse1)
if am then
am1 = am.func(ground_truth, model1_output)
model1_am = model1_am + am1
else
mse1 = MSE(ground_truth, model1_output, opt.color)
model1_mse = model1_mse + mse1
model1_psnr = model1_psnr + MSE2PSNR(mse1)
end
local won_model = 1
if model2 then
if opt.border > 0 then
model2_output = remove_border(model2_output, opt.border)
end
mse2 = MSE(ground_truth, model2_output, opt.color)
model2_mse = model2_mse + mse2
model2_psnr = model2_psnr + MSE2PSNR(mse2)
if mse1 < mse2 then
won[1] = won[1] + 1
elseif mse1 > mse2 then
won[2] = won[2] + 1
won_model = 2
if am then
am2 = am.func(ground_truth, model2_output)
model2_am = model2_am + am2
else
mse2 = MSE(ground_truth, model2_output, opt.color)
model2_mse = model2_mse + mse2
model2_psnr = model2_psnr + MSE2PSNR(mse2)
end
if am then
if am1 < am2 then
won[1] = won[1] + 1
elseif am1 > am2 then
won[2] = won[2] + 1
won_model = 2
end
else
if mse1 < mse2 then
won[1] = won[1] + 1
elseif mse1 > mse2 then
won[2] = won[2] + 1
won_model = 2
end
end
if detail_fp then
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
if am then
detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model))
else
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
end
end
else
if detail_fp then
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
if am then
detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
else
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
end
end
end
if baseline_output then
@ -450,46 +509,65 @@ local function benchmark(opt, x, model1, model2)
end
end
if opt.show_progress or i == #x then
if model2 then
if baseline_output then
if am then
if model2 then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
i, #x,
model1_time,
model2_time,
math.sqrt(baseline_mse / i),
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
baseline_psnr / i,
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
am.name, model1_am / i, am.name, model2_am / i
))
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
i, #x,
model1_time,
model2_time,
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
am.name, model1_am / i
))
end
else
if baseline_output then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
baseline_psnr / i, model1_psnr / i
if model2 then
if baseline_output then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
i, #x,
model1_time,
model2_time,
math.sqrt(baseline_mse / i),
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
baseline_psnr / i,
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
i, #x,
model1_time,
model2_time,
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
model1_psnr / i, model2_psnr / i,
won[1], won[2]
))
end
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(model1_mse / i), model1_psnr / i
if baseline_output then
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
baseline_psnr / i, model1_psnr / i
))
else
io.stdout:write(
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
i, #x,
model1_time,
math.sqrt(model1_mse / i), model1_psnr / i
))
end
end
end
io.stdout:flush()
@ -510,6 +588,14 @@ local function benchmark(opt, x, model1, model2)
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
end
if model1_am > 0 then
fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n",
math.sqrt(model1_am / #x), model1_time))
end
if model2_am > 0 then
fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n",
math.sqrt(model2_am / #x), model2_time))
end
fp:close()
if detail_fp then
detail_fp:close()