add support for jaccard in benchmark
This commit is contained in:
parent
385020e0e1
commit
d8b7df4505
|
@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th
|
|||
cmd:option("-x_file", "", 'input image for user method')
|
||||
cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file')
|
||||
cmd:option("-border", 0, 'border px that will removed')
|
||||
cmd:option("-metric", "", '(jaccard)')
|
||||
|
||||
local function to_bool(settings, name)
|
||||
if settings[name] == 1 then
|
||||
|
@ -198,8 +199,34 @@ local function remove_border(x, border)
|
|||
x:size(3) - border,
|
||||
x:size(2) - border)
|
||||
end
|
||||
local function create_metric(metric)
|
||||
if metric and metric:len() > 0 then
|
||||
if metric == "jaccard" then
|
||||
return {
|
||||
name = "jaccard",
|
||||
func = function (a, b)
|
||||
local ga = iproc.rgb2y(a)
|
||||
local gb = iproc.rgb2y(b)
|
||||
local ba = torch.Tensor():resizeAs(ga)
|
||||
local bb = torch.Tensor():resizeAs(gb)
|
||||
ba:zero()
|
||||
bb:zero()
|
||||
ba[torch.gt(ga, 0.5)] = 1.0
|
||||
bb[torch.gt(gb, 0.5)] = 1.0
|
||||
local num_a = ba:sum()
|
||||
local num_b = bb:sum()
|
||||
local a_and_b = ba:cmul(bb):sum()
|
||||
return (a_and_b / (num_a + num_b - a_and_b))
|
||||
end}
|
||||
else
|
||||
error("unknown metric: " .. metric)
|
||||
end
|
||||
else
|
||||
return nil
|
||||
end
|
||||
end
|
||||
local function benchmark(opt, x, model1, model2)
|
||||
local mse1, mse2
|
||||
local mse1, mse2, am1, am2
|
||||
local won = {0, 0}
|
||||
local model1_mse = 0
|
||||
local model2_mse = 0
|
||||
|
@ -212,6 +239,13 @@ local function benchmark(opt, x, model1, model2)
|
|||
local scale_f = reconstruct.scale
|
||||
local image_f = reconstruct.image
|
||||
local detail_fp = nil
|
||||
local am = nil
|
||||
local model1_am = 0
|
||||
local model2_am = 0
|
||||
|
||||
if opt.method == "user" or opt.method == "diff" then
|
||||
am = create_metric(opt.metric)
|
||||
end
|
||||
if opt.save_info then
|
||||
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
|
||||
end
|
||||
|
@ -401,32 +435,57 @@ local function benchmark(opt, x, model1, model2)
|
|||
ground_truth = remove_border(ground_truth, opt.border)
|
||||
model1_output = remove_border(model1_output, opt.border)
|
||||
end
|
||||
mse1 = MSE(ground_truth, model1_output, opt.color)
|
||||
model1_mse = model1_mse + mse1
|
||||
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
||||
|
||||
if am then
|
||||
am1 = am.func(ground_truth, model1_output)
|
||||
model1_am = model1_am + am1
|
||||
else
|
||||
mse1 = MSE(ground_truth, model1_output, opt.color)
|
||||
model1_mse = model1_mse + mse1
|
||||
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
||||
end
|
||||
local won_model = 1
|
||||
if model2 then
|
||||
if opt.border > 0 then
|
||||
model2_output = remove_border(model2_output, opt.border)
|
||||
end
|
||||
mse2 = MSE(ground_truth, model2_output, opt.color)
|
||||
model2_mse = model2_mse + mse2
|
||||
model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
||||
|
||||
if mse1 < mse2 then
|
||||
won[1] = won[1] + 1
|
||||
elseif mse1 > mse2 then
|
||||
won[2] = won[2] + 1
|
||||
won_model = 2
|
||||
if am then
|
||||
am2 = am.func(ground_truth, model2_output)
|
||||
model2_am = model2_am + am2
|
||||
else
|
||||
mse2 = MSE(ground_truth, model2_output, opt.color)
|
||||
model2_mse = model2_mse + mse2
|
||||
model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
||||
end
|
||||
if am then
|
||||
if am1 < am2 then
|
||||
won[1] = won[1] + 1
|
||||
elseif am1 > am2 then
|
||||
won[2] = won[2] + 1
|
||||
won_model = 2
|
||||
end
|
||||
else
|
||||
if mse1 < mse2 then
|
||||
won[1] = won[1] + 1
|
||||
elseif mse1 > mse2 then
|
||||
won[2] = won[2] + 1
|
||||
won_model = 2
|
||||
end
|
||||
end
|
||||
if detail_fp then
|
||||
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
||||
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
||||
if am then
|
||||
detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model))
|
||||
else
|
||||
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
||||
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
||||
end
|
||||
end
|
||||
else
|
||||
if detail_fp then
|
||||
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
||||
if am then
|
||||
detail_fp:write(string.format("%s,%f\n", x[i].basename, am1))
|
||||
else
|
||||
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
||||
end
|
||||
end
|
||||
end
|
||||
if baseline_output then
|
||||
|
@ -450,46 +509,65 @@ local function benchmark(opt, x, model1, model2)
|
|||
end
|
||||
end
|
||||
if opt.show_progress or i == #x then
|
||||
if model2 then
|
||||
if baseline_output then
|
||||
if am then
|
||||
if model2 then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
model2_time,
|
||||
math.sqrt(baseline_mse / i),
|
||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||
baseline_psnr / i,
|
||||
model1_psnr / i, model2_psnr / i,
|
||||
won[1], won[2]
|
||||
))
|
||||
am.name, model1_am / i, am.name, model2_am / i
|
||||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
||||
string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
model2_time,
|
||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||
model1_psnr / i, model2_psnr / i,
|
||||
won[1], won[2]
|
||||
))
|
||||
am.name, model1_am / i
|
||||
))
|
||||
end
|
||||
else
|
||||
if baseline_output then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
||||
baseline_psnr / i, model1_psnr / i
|
||||
if model2 then
|
||||
if baseline_output then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
model2_time,
|
||||
math.sqrt(baseline_mse / i),
|
||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||
baseline_psnr / i,
|
||||
model1_psnr / i, model2_psnr / i,
|
||||
won[1], won[2]
|
||||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
model2_time,
|
||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||
model1_psnr / i, model2_psnr / i,
|
||||
won[1], won[2]
|
||||
))
|
||||
end
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
math.sqrt(model1_mse / i), model1_psnr / i
|
||||
if baseline_output then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
||||
baseline_psnr / i, model1_psnr / i
|
||||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r",
|
||||
i, #x,
|
||||
model1_time,
|
||||
math.sqrt(model1_mse / i), model1_psnr / i
|
||||
))
|
||||
end
|
||||
end
|
||||
end
|
||||
io.stdout:flush()
|
||||
|
@ -510,6 +588,14 @@ local function benchmark(opt, x, model1, model2)
|
|||
fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n",
|
||||
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
||||
end
|
||||
if model1_am > 0 then
|
||||
fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n",
|
||||
math.sqrt(model1_am / #x), model1_time))
|
||||
end
|
||||
if model2_am > 0 then
|
||||
fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n",
|
||||
math.sqrt(model2_am / #x), model2_time))
|
||||
end
|
||||
fp:close()
|
||||
if detail_fp then
|
||||
detail_fp:close()
|
||||
|
|
Loading…
Reference in a new issue