diff --git a/tools/benchmark.lua b/tools/benchmark.lua index a9aa70f..7d5b7b9 100644 --- a/tools/benchmark.lua +++ b/tools/benchmark.lua @@ -181,7 +181,8 @@ local function transform_scale_jpeg(x, opt) end local function benchmark(opt, x, model1, model2) - local mse + local mse1, mse2 + local won = {0, 0} local model1_mse = 0 local model2_mse = 0 local baseline_mse = 0 @@ -351,13 +352,19 @@ local function benchmark(opt, x, model1, model2) ground_truth = x[i].y model1_output = input end - mse = MSE(ground_truth, model1_output, opt.color) - model1_mse = model1_mse + mse - model1_psnr = model1_psnr + MSE2PSNR(mse) + mse1 = MSE(ground_truth, model1_output, opt.color) + model1_mse = model1_mse + mse1 + model1_psnr = model1_psnr + MSE2PSNR(mse1) if model2 then - mse = MSE(ground_truth, model2_output, opt.color) - model2_mse = model2_mse + mse - model2_psnr = model2_psnr + MSE2PSNR(mse) + mse2 = MSE(ground_truth, model2_output, opt.color) + model2_mse = model2_mse + mse2 + model2_psnr = model2_psnr + MSE2PSNR(mse2) + + if mse1 < mse2 then + won[1] = won[1] + 1 + elseif mse1 > mse2 then + won[2] = won[2] + 1 + end end if baseline_output then mse = MSE(ground_truth, baseline_output, opt.color) @@ -382,23 +389,25 @@ local function benchmark(opt, x, model1, model2) if model2 then if baseline_output then io.stdout:write( - string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r", + string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f, model1_won=%d, model2_won=%d \r", i, #x, model1_time, model2_time, math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), baseline_psnr / i, - model1_psnr / i, model2_psnr / i + model1_psnr / i, model2_psnr / i, + won[1], won[2] )) else io.stdout:write( - string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r", + string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f, model1_own=%d, model2_won=%d \r", i, #x, model1_time, model2_time, math.sqrt(model1_mse / i), math.sqrt(model2_mse / i), - model1_psnr / i, model2_psnr / i + model1_psnr / i, model2_psnr / i, + won[1], won[2] )) end else