diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index 96f189f..a6f7abb 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -11,7 +11,7 @@ local function minibatch_adam(model, criterion, eval_metric, config.xEvalCount = 0 config.learningRate = config.xLearningRate end - + local sum_psnr = 0 local sum_loss = 0 local sum_eval = 0 local count_loss = 0 @@ -55,6 +55,7 @@ local function minibatch_adam(model, criterion, eval_metric, else se = eval_metric:forward(output, targets) end + sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6))) sum_eval = sum_eval + se sum_loss = sum_loss + f count_loss = count_loss + 1 @@ -69,10 +70,9 @@ local function minibatch_adam(model, criterion, eval_metric, collectgarbage() xlua.progress(t, train_x:size(1)) end - end xlua.progress(train_x:size(1), train_x:size(1)) - return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}, instance_loss + return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = sum_psnr / count_loss}, instance_loss end return minibatch_adam diff --git a/train.lua b/train.lua index 382bdc1..f58280d 100644 --- a/train.lua +++ b/train.lua @@ -262,6 +262,7 @@ local function make_validation_set(x, n, patches) return data end local function validate(model, criterion, eval_metric, data, batch_size) + local psnr = 0 local loss = 0 local mse = 0 local loss_count = 0 @@ -286,8 +287,10 @@ local function validate(model, criterion, eval_metric, data, batch_size) inputs:copy(inputs_tmp) targets:copy(targets_tmp) local z = model:forward(inputs) + local batch_mse = eval_metric:forward(z, targets) loss = loss + criterion:forward(z, targets) - mse = mse + eval_metric:forward(z, targets) + mse = mse + batch_mse + psnr = psnr + (10 * math.log10(1 / batch_mse)) loss_count = loss_count + 1 if loss_count % 10 == 0 then xlua.progress(t, #data) @@ -295,7 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size) end end xlua.progress(#data, #data) - return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))} + return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = psnr / loss_count} end local function create_criterion(model)