Fix division by zero error in validate()
This commit is contained in:
parent
d8b7df4505
commit
a121eb39cb
|
@ -290,7 +290,7 @@ local function validate(model, criterion, eval_metric, data, batch_size)
|
||||||
local batch_mse = eval_metric:forward(z, targets)
|
local batch_mse = eval_metric:forward(z, targets)
|
||||||
loss = loss + criterion:forward(z, targets)
|
loss = loss + criterion:forward(z, targets)
|
||||||
mse = mse + batch_mse
|
mse = mse + batch_mse
|
||||||
psnr = psnr + (10 * math.log10(1 / batch_mse))
|
psnr = psnr + (10 * math.log10(1 / (batch_mse + 1.0e-6))
|
||||||
loss_count = loss_count + 1
|
loss_count = loss_count + 1
|
||||||
if loss_count % 10 == 0 then
|
if loss_count % 10 == 0 then
|
||||||
xlua.progress(t, #data)
|
xlua.progress(t, #data)
|
||||||
|
|
Loading…
Reference in a new issue