From d8b7df450576b5b28b87a55b288a0a6f0ea636d5 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Tue, 21 Feb 2017 21:39:56 +0900 Subject: [PATCH] add support for jaccard in benchmark --- tools/benchmark.lua | 174 +++++++++++++++++++++++++++++++++----------- 1 file changed, 130 insertions(+), 44 deletions(-) diff --git a/tools/benchmark.lua b/tools/benchmark.lua index 7e93ac6..6beb72c 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -47,6 +47,7 @@ cmd:option("-y_dir", "", 'groundtruth image for user method. filename must be th cmd:option("-x_file", "", 'input image for user method') cmd:option("-y_file", "", 'groundtruth image for user method. filename must be the same as x_file') cmd:option("-border", 0, 'border px that will removed') +cmd:option("-metric", "", '(jaccard)') local function to_bool(settings, name) if settings[name] == 1 then @@ -198,8 +199,34 @@ local function remove_border(x, border) x:size(3) - border, x:size(2) - border) end +local function create_metric(metric) + if metric and metric:len() > 0 then + if metric == "jaccard" then + return { + name = "jaccard", + func = function (a, b) + local ga = iproc.rgb2y(a) + local gb = iproc.rgb2y(b) + local ba = torch.Tensor():resizeAs(ga) + local bb = torch.Tensor():resizeAs(gb) + ba:zero() + bb:zero() + ba[torch.gt(ga, 0.5)] = 1.0 + bb[torch.gt(gb, 0.5)] = 1.0 + local num_a = ba:sum() + local num_b = bb:sum() + local a_and_b = ba:cmul(bb):sum() + return (a_and_b / (num_a + num_b - a_and_b)) + end} + else + error("unknown metric: " .. metric) + end + else + return nil + end +end local function benchmark(opt, x, model1, model2) - local mse1, mse2 + local mse1, mse2, am1, am2 local won = {0, 0} local model1_mse = 0 local model2_mse = 0 @@ -212,6 +239,13 @@ local function benchmark(opt, x, model1, model2) local scale_f = reconstruct.scale local image_f = reconstruct.image local detail_fp = nil + local am = nil + local model1_am = 0 + local model2_am = 0 + + if opt.method == "user" or opt.method == "diff" then + am = create_metric(opt.metric) + end if opt.save_info then detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w") end @@ -401,32 +435,57 @@ local function benchmark(opt, x, model1, model2) ground_truth = remove_border(ground_truth, opt.border) model1_output = remove_border(model1_output, opt.border) end - mse1 = MSE(ground_truth, model1_output, opt.color) - model1_mse = model1_mse + mse1 - model1_psnr = model1_psnr + MSE2PSNR(mse1) - + if am then + am1 = am.func(ground_truth, model1_output) + model1_am = model1_am + am1 + else + mse1 = MSE(ground_truth, model1_output, opt.color) + model1_mse = model1_mse + mse1 + model1_psnr = model1_psnr + MSE2PSNR(mse1) + end local won_model = 1 if model2 then if opt.border > 0 then model2_output = remove_border(model2_output, opt.border) end - mse2 = MSE(ground_truth, model2_output, opt.color) - model2_mse = model2_mse + mse2 - model2_psnr = model2_psnr + MSE2PSNR(mse2) - - if mse1 < mse2 then - won[1] = won[1] + 1 - elseif mse1 > mse2 then - won[2] = won[2] + 1 - won_model = 2 + if am then + am2 = am.func(ground_truth, model2_output) + model2_am = model2_am + am2 + else + mse2 = MSE(ground_truth, model2_output, opt.color) + model2_mse = model2_mse + mse2 + model2_psnr = model2_psnr + MSE2PSNR(mse2) + end + if am then + if am1 < am2 then + won[1] = won[1] + 1 + elseif am1 > am2 then + won[2] = won[2] + 1 + won_model = 2 + end + else + if mse1 < mse2 then + won[1] = won[1] + 1 + elseif mse1 > mse2 then + won[2] = won[2] + 1 + won_model = 2 + end end if detail_fp then - detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename, - MSE2PSNR(mse1), MSE2PSNR(mse2), won_model)) + if am then + detail_fp:write(string.format("%s,%f,%d\n", x[i].basename, am1, am2, won_model)) + else + detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename, + MSE2PSNR(mse1), MSE2PSNR(mse2), won_model)) + end end else if detail_fp then - detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1))) + if am then + detail_fp:write(string.format("%s,%f\n", x[i].basename, am1)) + else + detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1))) + end end end if baseline_output then @@ -450,46 +509,65 @@ local function benchmark(opt, x, model1, model2) end end if opt.show_progress or i == #x then - if model2 then - if baseline_output then + if am then + if model2 then io.stdout:write( - string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r", + string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_%s=%.3f, model2_%s=%.3f \r", i, #x, model1_time, model2_time, - math.sqrt(baseline_mse / i), - math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), - baseline_psnr / i, - model1_psnr / i, model2_psnr / i, - won[1], won[2] - )) + am.name, model1_am / i, am.name, model2_am / i + )) else io.stdout:write( - string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r", + string.format("%d/%d; model1_time=%.2f, model1_%s=%.3f \r", i, #x, model1_time, - model2_time, - math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), - model1_psnr / i, model2_psnr / i, - won[1], won[2] - )) + am.name, model1_am / i + )) end else - if baseline_output then - io.stdout:write( - string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r", - i, #x, - model1_time, - math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i), - baseline_psnr / i, model1_psnr / i + if model2 then + if baseline_output then + io.stdout:write( + string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, model2_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_won=%d, model2_won=%d \r", + i, #x, + model1_time, + model2_time, + math.sqrt(baseline_mse / i), + math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), + baseline_psnr / i, + model1_psnr / i, model2_psnr / i, + won[1], won[2] )) + else + io.stdout:write( + string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%.3f, model2_rmse=%.3f, model1_psnr=%.3f, model2_psnr=%.3f, model1_own=%d, model2_won=%d \r", + i, #x, + model1_time, + model2_time, + math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), + model1_psnr / i, model2_psnr / i, + won[1], won[2] + )) + end else - io.stdout:write( - string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r", - i, #x, - model1_time, - math.sqrt(model1_mse / i), model1_psnr / i + if baseline_output then + io.stdout:write( + string.format("%d/%d; model1_time=%.2f, baseline_rmse=%.3f, model1_rmse=%.3f, baseline_psnr=%.3f, model1_psnr=%.3f \r", + i, #x, + model1_time, + math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i), + baseline_psnr / i, model1_psnr / i )) + else + io.stdout:write( + string.format("%d/%d; model1_time=%.2f, model1_rmse=%.3f, model1_psnr=%.3f \r", + i, #x, + model1_time, + math.sqrt(model1_mse / i), model1_psnr / i + )) + end end end io.stdout:flush() @@ -510,6 +588,14 @@ local function benchmark(opt, x, model1, model2) fp:write(string.format("model2 : RMSE = %.3f, PSNR = %.3f, evaluation time = %.3f\n", math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time)) end + if model1_am > 0 then + fp:write(string.format("model1 : %s = %.3f, evaluation time = %.3f\n", + math.sqrt(model1_am / #x), model1_time)) + end + if model2_am > 0 then + fp:write(string.format("model2 : %s = %.3f, evaluation time = %.3f\n", + math.sqrt(model2_am / #x), model2_time)) + end fp:close() if detail_fp then detail_fp:close()