add support for jaccard in benchmark
This commit is contained in:
parent
385020e0e1
commit
d8b7df4505
|
@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
|
||||||
cmd:option("-x_file", "", 'input image for user method')
|
cmd:option("-x_file", "", 'input image for user method')
|
||||||
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
|
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
|
||||||
cmd:option("-border", 0, 'border px that will removed')
|
cmd:option("-border", 0, 'border px that will removed')
|
||||||
|
cmd:option("-metric", "", '(jaccard)')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
if settings[name] == 1 then
|
if settings[name] == 1 then
|
||||||
|
@ -198,8 +199,34 @@ local function remove_border(x, border)
|
||||||
x:size(3) - border,
|
x:size(3) - border,
|
||||||
x:size(2) - border)
|
x:size(2) - border)
|
||||||
end
|
end
|
||||||
|
local function create_metric(metric)
|
||||||
|
if metric and metric:len() > 0 then
|
||||||
|
if metric == "jaccard" then
|
||||||
|
return {
|
||||||
|
name = "jaccard",
|
||||||
|
func = function (a, b)
|
||||||
|
local ga = iproc.rgb2y(a)
|
||||||
|
local gb = iproc.rgb2y(b)
|
||||||
|
local ba = torch.Tensor():resizeAs(ga)
|
||||||
|
local bb = torch.Tensor():resizeAs(gb)
|
||||||
|
ba:zero()
|
||||||
|
bb:zero()
|
||||||
|
ba[torch.gt(ga, 0.5)] = 1.0
|
||||||
|
bb[torch.gt(gb, 0.5)] = 1.0
|
||||||
|
local num_a = ba:sum()
|
||||||
|
local num_b = bb:sum()
|
||||||
|
local a_and_b = ba:cmul(bb):sum()
|
||||||
|
return (a_and_b / (num_a + num_b - a_and_b))
|
||||||
|
end}
|
||||||
|
else
|
||||||
|
error("unknown metric: " .. metric)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
end
|
||||||
local function benchmark(opt, x, model1, model2)
|
local function benchmark(opt, x, model1, model2)
|
||||||
local mse1, mse2
|
local mse1, mse2, am1, am2
|
||||||
local won = {0, 0}
|
local won = {0, 0}
|
||||||
local model1_mse = 0
|
local model1_mse = 0
|
||||||
local model2_mse = 0
|
local model2_mse = 0
|
||||||
|
@ -212,6 +239,13 @@ local function benchmark(opt, x, model1, model2)
|
||||||
local scale_f = reconstruct.scale
|
local scale_f = reconstruct.scale
|
||||||
local image_f = reconstruct.image
|
local image_f = reconstruct.image
|
||||||
local detail_fp = nil
|
local detail_fp = nil
|
||||||
|
local am = nil
|
||||||
|
local model1_am = 0
|
||||||
|
local model2_am = 0
|
||||||
|
|
||||||
|
if opt.method == "user" or opt.method == "diff" then
|
||||||
|
am = create_metric(opt.metric)
|
||||||
|
end
|
||||||
if opt.save_info then
|
if opt.save_info then
|
||||||
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
|
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
|
||||||
end
|
end
|
||||||
|
@ -401,32 +435,57 @@ local function benchmark(opt, x, model1, model2)
|
||||||
ground_truth = remove_border(ground_truth, opt.border)
|
ground_truth = remove_border(ground_truth, opt.border)
|
||||||
model1_output = remove_border(model1_output, opt.border)
|
model1_output = remove_border(model1_output, opt.border)
|
||||||
end
|
end
|
||||||
mse1 = MSE(ground_truth, model1_output, opt.color)
|
if am then
|
||||||
model1_mse = model1_mse + mse1
|
am1 = am.func(ground_truth, model1_output)
|
||||||
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
model1_am = model1_am + am1
|
||||||
|
else
|
||||||
|
mse1 = MSE(ground_truth, model1_output, opt.color)
|
||||||
|
model1_mse = model1_mse + mse1
|
||||||
|
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
||||||
|
end
|
||||||
local won_model = 1
|
local won_model = 1
|
||||||
if model2 then
|
if model2 then
|
||||||
if opt.border > 0 then
|
if opt.border > 0 then
|
||||||
model2_output = remove_border(model2_output, opt.border)
|
model2_output = remove_border(model2_output, opt.border)
|
||||||
end
|
end
|
||||||
mse2 = MSE(ground_truth, model2_output, opt.color)
|
if am then
|
||||||
model2_mse = model2_mse + mse2
|
am2 = am.func(ground_truth, model2_output)
|
||||||
model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
model2_am = model2_am + am2
|
||||||
|
else
|
||||||
if mse1 < mse2 then
|
mse2 = MSE(ground_truth, model2_output, opt.color)
|
||||||
won[1] = won[1] + 1
|
model2_mse = model2_mse + mse2
|
||||||
elseif mse1 > mse2 then
|
model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
||||||
won[2] = won[2] + 1
|
end
|
||||||
won_model = 2
|
if am then
|
||||||
|
if am1 < am2 then
|
||||||
|
won[1] = won[1] + 1
|
||||||
|
elseif am1 > am2 then
|
||||||
|
won[2] = won[2] + 1
|
||||||
|
won_model = 2
|
||||||
|
end
|
||||||
|
else
|
||||||
|
if mse1 < mse2 then
|
||||||
|
won[1] = won[1] + 1
|
||||||
|
elseif mse1 > mse2 then
|
||||||
|
won[2] = won[2] + 1
|
||||||
|
won_model = 2
|
||||||
|
end
|
||||||
end
|
end
|
||||||
if detail_fp then
|
if detail_fp then
|
||||||
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
if am then
|
||||||
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model))
|
||||||
|
else
|
||||||
|
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
||||||
|
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
if detail_fp then
|
if detail_fp then
|
||||||
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
if am then
|
||||||
|
detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
|
||||||
|
else
|
||||||
|
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if baseline_output then
|
if baseline_output then
|
||||||
|
@ -450,46 +509,65 @@ local function benchmark(opt, x, model1, model2)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if opt.show_progress or i == #x then
|
if opt.show_progress or i == #x then
|
||||||
if model2 then
|
if am then
|
||||||
if baseline_output then
|
if model2 then
|
||||||
io.stdout:write(
|
io.stdout:write(
|
||||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
|
||||||
i, #x,
|
i, #x,
|
||||||
model1_time,
|
model1_time,
|
||||||
model2_time,
|
model2_time,
|
||||||
math.sqrt(baseline_mse / i),
|
am.name, model1_am / i, am.name, model2_am / i
|
||||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
))
|
||||||
baseline_psnr / i,
|
|
||||||
model1_psnr / i, model2_psnr / i,
|
|
||||||
won[1], won[2]
|
|
||||||
))
|
|
||||||
else
|
else
|
||||||
io.stdout:write(
|
io.stdout:write(
|
||||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
|
||||||
i, #x,
|
i, #x,
|
||||||
model1_time,
|
model1_time,
|
||||||
model2_time,
|
am.name, model1_am / i
|
||||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
))
|
||||||
model1_psnr / i, model2_psnr / i,
|
|
||||||
won[1], won[2]
|
|
||||||
))
|
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
if baseline_output then
|
if model2 then
|
||||||
io.stdout:write(
|
if baseline_output then
|
||||||
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
io.stdout:write(
|
||||||
i, #x,
|
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
||||||
model1_time,
|
i, #x,
|
||||||
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
model1_time,
|
||||||
baseline_psnr / i, model1_psnr / i
|
model2_time,
|
||||||
|
math.sqrt(baseline_mse / i),
|
||||||
|
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||||
|
baseline_psnr / i,
|
||||||
|
model1_psnr / i, model2_psnr / i,
|
||||||
|
won[1], won[2]
|
||||||
))
|
))
|
||||||
|
else
|
||||||
|
io.stdout:write(
|
||||||
|
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
||||||
|
i, #x,
|
||||||
|
model1_time,
|
||||||
|
model2_time,
|
||||||
|
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||||
|
model1_psnr / i, model2_psnr / i,
|
||||||
|
won[1], won[2]
|
||||||
|
))
|
||||||
|
end
|
||||||
else
|
else
|
||||||
io.stdout:write(
|
if baseline_output then
|
||||||
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
io.stdout:write(
|
||||||
i, #x,
|
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
||||||
model1_time,
|
i, #x,
|
||||||
math.sqrt(model1_mse / i), model1_psnr / i
|
model1_time,
|
||||||
|
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
||||||
|
baseline_psnr / i, model1_psnr / i
|
||||||
))
|
))
|
||||||
|
else
|
||||||
|
io.stdout:write(
|
||||||
|
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
||||||
|
i, #x,
|
||||||
|
model1_time,
|
||||||
|
math.sqrt(model1_mse / i), model1_psnr / i
|
||||||
|
))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
io.stdout:flush()
|
io.stdout:flush()
|
||||||
|
@ -510,6 +588,14 @@ local function benchmark(opt, x, model1, model2)
|
||||||
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
|
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
|
||||||
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
||||||
end
|
end
|
||||||
|
if model1_am > 0 then
|
||||||
|
fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n",
|
||||||
|
math.sqrt(model1_am / #x), model1_time))
|
||||||
|
end
|
||||||
|
if model2_am > 0 then
|
||||||
|
fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n",
|
||||||
|
math.sqrt(model2_am / #x), model2_time))
|
||||||
|
end
|
||||||
fp:close()
|
fp:close()
|
||||||
if detail_fp then
|
if detail_fp then
|
||||||
detail_fp:close()
|
detail_fp:close()
|
||||||
|
|
Loading…
Reference in a new issue