Add count of won in benchmark
This commit is contained in:
parent
d7ab10581c
commit
51e189b6c0
|
@ -181,7 +181,8 @@ local function transform_scale_jpeg(x, opt)
|
||||||
end
|
end
|
||||||
|
|
||||||
local function benchmark(opt, x, model1, model2)
|
local function benchmark(opt, x, model1, model2)
|
||||||
local mse
|
local mse1, mse2
|
||||||
|
local won = {0, 0}
|
||||||
local model1_mse = 0
|
local model1_mse = 0
|
||||||
local model2_mse = 0
|
local model2_mse = 0
|
||||||
local baseline_mse = 0
|
local baseline_mse = 0
|
||||||
|
@ -351,13 +352,19 @@ local function benchmark(opt, x, model1, model2)
|
||||||
ground_truth = x[i].y
|
ground_truth = x[i].y
|
||||||
model1_output = input
|
model1_output = input
|
||||||
end
|
end
|
||||||
mse = MSE(ground_truth, model1_output, opt.color)
|
mse1 = MSE(ground_truth, model1_output, opt.color)
|
||||||
model1_mse = model1_mse + mse
|
model1_mse = model1_mse + mse1
|
||||||
model1_psnr = model1_psnr + MSE2PSNR(mse)
|
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
||||||
if model2 then
|
if model2 then
|
||||||
mse = MSE(ground_truth, model2_output, opt.color)
|
mse2 = MSE(ground_truth, model2_output, opt.color)
|
||||||
model2_mse = model2_mse + mse
|
model2_mse = model2_mse + mse2
|
||||||
model2_psnr = model2_psnr + MSE2PSNR(mse)
|
model2_psnr = model2_psnr + MSE2PSNR(mse2)
|
||||||
|
|
||||||
|
if mse1 < mse2 then
|
||||||
|
won[1] = won[1] + 1
|
||||||
|
elseif mse1 > mse2 then
|
||||||
|
won[2] = won[2] + 1
|
||||||
|
end
|
||||||
end
|
end
|
||||||
if baseline_output then
|
if baseline_output then
|
||||||
mse = MSE(ground_truth, baseline_output, opt.color)
|
mse = MSE(ground_truth, baseline_output, opt.color)
|
||||||
|
@ -382,23 +389,25 @@ local function benchmark(opt, x, model1, model2)
|
||||||
if model2 then
|
if model2 then
|
||||||
if baseline_output then
|
if baseline_output then
|
||||||
io.stdout:write(
|
io.stdout:write(
|
||||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
|
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f, model1_won=%d, model2_won=%d \r",
|
||||||
i, #x,
|
i, #x,
|
||||||
model1_time,
|
model1_time,
|
||||||
model2_time,
|
model2_time,
|
||||||
math.sqrt(baseline_mse / i),
|
math.sqrt(baseline_mse / i),
|
||||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||||
baseline_psnr / i,
|
baseline_psnr / i,
|
||||||
model1_psnr / i, model2_psnr / i
|
model1_psnr / i, model2_psnr / i,
|
||||||
|
won[1], won[2]
|
||||||
))
|
))
|
||||||
else
|
else
|
||||||
io.stdout:write(
|
io.stdout:write(
|
||||||
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
string.format("%d/%d; model1_time=%.2f, model2_time=%.2f, model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f, model1_own=%d, model2_won=%d \r",
|
||||||
i, #x,
|
i, #x,
|
||||||
model1_time,
|
model1_time,
|
||||||
model2_time,
|
model2_time,
|
||||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||||
model1_psnr / i, model2_psnr / i
|
model1_psnr / i, model2_psnr / i,
|
||||||
|
won[1], won[2]
|
||||||
))
|
))
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
|
|
Loading…
Reference in a new issue