more detailed log in benchmark
This commit is contained in:
parent
51e189b6c0
commit
1ce5a8d038
|
@ -193,6 +193,10 @@ local function benchmark(opt, x, model1, model2)
|
||||||
local model2_time = 0
|
local model2_time = 0
|
||||||
local scale_f = reconstruct.scale
|
local scale_f = reconstruct.scale
|
||||||
local image_f = reconstruct.image
|
local image_f = reconstruct.image
|
||||||
|
local detail_fp = nil
|
||||||
|
if opt.save_info then
|
||||||
|
detail_fp = io.open(path.join(opt.output_dir, "benchmark_details.txt"), "w")
|
||||||
|
end
|
||||||
if opt.tta then
|
if opt.tta then
|
||||||
scale_f = function(model, scale, x, block_size, batch_size)
|
scale_f = function(model, scale, x, block_size, batch_size)
|
||||||
return reconstruct.scale_tta(model, opt.tta_level,
|
return reconstruct.scale_tta(model, opt.tta_level,
|
||||||
|
@ -355,6 +359,8 @@ local function benchmark(opt, x, model1, model2)
|
||||||
mse1 = MSE(ground_truth, model1_output, opt.color)
|
mse1 = MSE(ground_truth, model1_output, opt.color)
|
||||||
model1_mse = model1_mse + mse1
|
model1_mse = model1_mse + mse1
|
||||||
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
model1_psnr = model1_psnr + MSE2PSNR(mse1)
|
||||||
|
|
||||||
|
local won_model = 1
|
||||||
if model2 then
|
if model2 then
|
||||||
mse2 = MSE(ground_truth, model2_output, opt.color)
|
mse2 = MSE(ground_truth, model2_output, opt.color)
|
||||||
model2_mse = model2_mse + mse2
|
model2_mse = model2_mse + mse2
|
||||||
|
@ -364,6 +370,15 @@ local function benchmark(opt, x, model1, model2)
|
||||||
won[1] = won[1] + 1
|
won[1] = won[1] + 1
|
||||||
elseif mse1 > mse2 then
|
elseif mse1 > mse2 then
|
||||||
won[2] = won[2] + 1
|
won[2] = won[2] + 1
|
||||||
|
won_model = 2
|
||||||
|
end
|
||||||
|
if detail_fp then
|
||||||
|
detail_fp:write(string.format("%s,%f,%f,%d\n", x[i].basename,
|
||||||
|
MSE2PSNR(mse1), MSE2PSNR(mse2), won_model))
|
||||||
|
end
|
||||||
|
else
|
||||||
|
if detail_fp then
|
||||||
|
detail_fp:write(string.format("%s,%f\n", x[i].basename, MSE2PSNR(mse1)))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
if baseline_output then
|
if baseline_output then
|
||||||
|
@ -447,6 +462,9 @@ local function benchmark(opt, x, model1, model2)
|
||||||
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
math.sqrt(model2_mse / #x), model2_psnr / #x, model2_time))
|
||||||
end
|
end
|
||||||
fp:close()
|
fp:close()
|
||||||
|
if detail_fp then
|
||||||
|
detail_fp:close()
|
||||||
|
end
|
||||||
end
|
end
|
||||||
io.stdout:write("\n")
|
io.stdout:write("\n")
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue