fix batchwise psnr
This commit is contained in:
parent
b111901cbb
commit
29260ede24
|
@ -11,7 +11,7 @@ local function minibatch_adam(model, criterion, eval_metric,
|
|||
config.xEvalCount = 0
|
||||
config.learningRate = config.xLearningRate
|
||||
end
|
||||
|
||||
local sum_psnr = 0
|
||||
local sum_loss = 0
|
||||
local sum_eval = 0
|
||||
local count_loss = 0
|
||||
|
@ -55,6 +55,7 @@ local function minibatch_adam(model, criterion, eval_metric,
|
|||
else
|
||||
se = eval_metric:forward(output, targets)
|
||||
end
|
||||
sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
|
||||
sum_eval = sum_eval + se
|
||||
sum_loss = sum_loss + f
|
||||
count_loss = count_loss + 1
|
||||
|
@ -69,10 +70,9 @@ local function minibatch_adam(model, criterion, eval_metric,
|
|||
collectgarbage()
|
||||
xlua.progress(t, train_x:size(1))
|
||||
end
|
||||
|
||||
end
|
||||
xlua.progress(train_x:size(1), train_x:size(1))
|
||||
return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}, instance_loss
|
||||
return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = sum_psnr / count_loss}, instance_loss
|
||||
end
|
||||
|
||||
return minibatch_adam
|
||||
|
|
|
@ -262,6 +262,7 @@ local function make_validation_set(x, n, patches)
|
|||
return data
|
||||
end
|
||||
local function validate(model, criterion, eval_metric, data, batch_size)
|
||||
local psnr = 0
|
||||
local loss = 0
|
||||
local mse = 0
|
||||
local loss_count = 0
|
||||
|
@ -286,8 +287,10 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|||
inputs:copy(inputs_tmp)
|
||||
targets:copy(targets_tmp)
|
||||
local z = model:forward(inputs)
|
||||
local batch_mse = eval_metric:forward(z, targets)
|
||||
loss = loss + criterion:forward(z, targets)
|
||||
mse = mse + eval_metric:forward(z, targets)
|
||||
mse = mse + batch_mse
|
||||
psnr = psnr + (10 * math.log10(1 / batch_mse))
|
||||
loss_count = loss_count + 1
|
||||
if loss_count % 10 == 0 then
|
||||
xlua.progress(t, #data)
|
||||
|
@ -295,7 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
|||
end
|
||||
end
|
||||
xlua.progress(#data, #data)
|
||||
return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))}
|
||||
return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = psnr / loss_count}
|
||||
end
|
||||
|
||||
local function create_criterion(model)
|
||||
|
|
Loading…
Reference in a new issue