1
0
Fork 0
mirror of synced 2024-06-13 00:14:31 +12:00

show baseline

This commit is contained in:
nagadomi 2015-11-21 00:29:57 +09:00
parent f7b298690d
commit c47df93505

View file

@ -69,12 +69,14 @@ end
local function benchmark(opt, x, input_func, model1, model2)
local model1_mse = 0
local model2_mse = 0
local baseline_mse = 0
local model1_psnr = 0
local model2_psnr = 0
local baseline_psnr = 0
for i = 1, #x do
local ground_truth = x[i]
local input, model1_output, model2_output
local input, model1_output, model2_output, baseline_output
input = input_func(ground_truth, opt)
input = input:float():div(255)
@ -91,6 +93,7 @@ local function benchmark(opt, x, input_func, model1, model2)
if model2 then
model2_output = reconstruct.scale(model2, 2.0, input)
end
baseline_output = iproc.scale(input, input:size(3) * 2, input:size(2) * 2, opt.filter)
end
if opt.color == "y" then
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
@ -99,6 +102,10 @@ local function benchmark(opt, x, input_func, model1, model2)
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
end
if baseline_output then
baseline_mse = baseline_mse + YMSE(ground_truth, baseline_output)
baseline_psnr = baseline_psnr + YPSNR(ground_truth, baseline_output)
end
elseif opt.color == "rgb" then
model1_mse = model1_mse + MSE(ground_truth, model1_output)
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
@ -106,22 +113,46 @@ local function benchmark(opt, x, input_func, model1, model2)
model2_mse = model2_mse + MSE(ground_truth, model2_output)
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
end
if baseline_output then
baseline_mse = baseline_mse + MSE(ground_truth, baseline_output)
baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output)
end
else
error("Unknown color: " .. opt.color)
end
if model2 then
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
i, #x,
model1_mse / i, model2_mse / i,
model1_psnr / i, model2_psnr / i
if baseline_output then
io.stdout:write(
string.format("%d/%d; baseline_mse=%f, model1_mse=%f, model2_mse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
i, #x,
baseline_mse / i,
model1_mse / i, model2_mse / i,
baseline_psnr / i,
model1_psnr / i, model2_psnr / i
))
else
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
i, #x,
model1_mse / i, model2_mse / i,
model1_psnr / i, model2_psnr / i
))
end
else
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
i, #x,
model1_mse / i, model1_psnr / i
if baseline_output then
io.stdout:write(
string.format("%d/%d; baseline_mse=%f, model1_mse=%f, baseline_psnr=%f, model1_psnr=%f \r",
i, #x,
baseline_mse / i, model1_mse / i,
baseline_psnr / i, model1_psnr / i
))
else
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
i, #x,
model1_mse / i, model1_psnr / i
))
end
end
io.stdout:flush()
end