diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index 96f189f..a6f7abb 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -11,7 +11,7 @@ local function minibatch_adam(model, criterion, eval_metric, config.xEvalCount = 0 config.learningRate = config.xLearningRate end - + local sum_psnr = 0 local sum_loss = 0 local sum_eval = 0 local count_loss = 0 @@ -55,6 +55,7 @@ local function minibatch_adam(model, criterion, eval_metric, else se = eval_metric:forward(output, targets) end + sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6))) sum_eval = sum_eval + se sum_loss = sum_loss + f count_loss = count_loss + 1 @@ -69,10 +70,9 @@ local function minibatch_adam(model, criterion, eval_metric, collectgarbage() xlua.progress(t, train_x:size(1)) end - end xlua.progress(train_x:size(1), train_x:size(1)) - return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))}, instance_loss + return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = sum_psnr / count_loss}, instance_loss end return minibatch_adam diff --git a/lib/settings.lua b/lib/settings.lua index 4507711..35f271f 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -76,6 +76,7 @@ cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') cmd:option("-gpu", 1, 'Device ID') cmd:option("-loss", "huber", 'loss function (huber|l1|mse)') +cmd:option("-update_criterion", "mse", 'mse|loss') local function to_bool(settings, name) if settings[name] == 1 then diff --git a/train.lua b/train.lua index be215f0..3f2eecb 100644 --- a/train.lua +++ b/train.lua @@ -262,6 +262,7 @@ local function make_validation_set(x, n, patches) return data end local function validate(model, criterion, eval_metric, data, batch_size) + local psnr = 0 local loss = 0 local mse = 0 local loss_count = 0 @@ -286,8 +287,10 @@ local function validate(model, criterion, eval_metric, data, batch_size) inputs:copy(inputs_tmp) targets:copy(targets_tmp) local z = model:forward(inputs) + local batch_mse = eval_metric:forward(z, targets) loss = loss + criterion:forward(z, targets) - mse = mse + eval_metric:forward(z, targets) + mse = mse + batch_mse + psnr = psnr + (10 * math.log10(1 / batch_mse)) loss_count = loss_count + 1 if loss_count % 10 == 0 then xlua.progress(t, #data) @@ -295,7 +298,7 @@ local function validate(model, criterion, eval_metric, data, batch_size) end end xlua.progress(#data, #data) - return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = 10 * math.log10(1 / (mse / loss_count))} + return {loss = loss / loss_count, MSE = mse / loss_count, PSNR = psnr / loss_count} end local function create_criterion(model) @@ -540,9 +543,15 @@ local function train() if settings.plot then plot(hist_train, hist_valid) end - if score.loss < best_score then + local score_for_update + if settings.update_criterion == "mse" then + score_for_update = score.MSE + else + score_for_update = score.loss + end + if score_for_update < best_score then local test_image = image_loader.load_float(settings.test) -- reload - best_score = score.loss + best_score = score_for_update print("* model has updated") if settings.save_history then torch.save(settings.model_file_best, model:clearState(), "ascii") @@ -591,7 +600,7 @@ local function train() end end end - print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE) + print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score) collectgarbage() end end