From b111901cbb4b791b423bc34997da37af98b697e1 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 12 Feb 2017 01:56:03 +0900 Subject: [PATCH] add -update_criterion option for back compatible --- lib/settings.lua | 1 + train.lua | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/lib/settings.lua b/lib/settings.lua index 4507711..35f271f 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -76,6 +76,7 @@ cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') cmd:option("-gpu", 1, 'Device ID') cmd:option("-loss", "huber", 'loss function (huber|l1|mse)') +cmd:option("-update_criterion", "mse", 'mse|loss') local function to_bool(settings, name) if settings[name] == 1 then diff --git a/train.lua b/train.lua index 77536f5..382bdc1 100644 --- a/train.lua +++ b/train.lua @@ -532,9 +532,15 @@ local function train() if settings.plot then plot(hist_train, hist_valid) end - if score.loss < best_score then + local score_for_update + if settings.update_criterion == "mse" then + score_for_update = score.MSE + else + score_for_update = score.loss + end + if score_for_update < best_score then local test_image = image_loader.load_float(settings.test) -- reload - best_score = score.loss + best_score = score_for_update print("* model has updated") if settings.save_history then torch.save(settings.model_file_best, model:clearState(), "ascii") @@ -583,7 +589,7 @@ local function train() end end end - print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", Minimum loss: " .. best_score .. ", MSE: " .. score.MSE) + print("Batch-wise PSNR: " .. score.PSNR .. ", loss: " .. score.loss .. ", MSE: " .. score.MSE .. ", best: " .. best_score) collectgarbage() end end