From 68a6d4cef58d7b068ed21f184e8eece7425362bb Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 17 Apr 2016 02:08:07 +0900 Subject: [PATCH] Use MSE instead of PSNR PSNR depends on the minibatch size and those group. --- lib/minibatch_adam.lua | 3 +-- train.lua | 11 +++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index dade6f7..10a9fec 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -52,8 +52,7 @@ local function minibatch_adam(model, criterion, eval_metric, end end xlua.progress(train_x:size(1), train_x:size(1)) - - return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss} + return { loss = sum_loss / count_loss, MSE = sum_eval / count_loss, PSNR = 10 * math.log10(1 / (sum_eval / count_loss))} end return minibatch_adam diff --git a/train.lua b/train.lua index c4ee633..077e1b0 100644 --- a/train.lua +++ b/train.lua @@ -198,7 +198,7 @@ local function train() return transformer(x, is_validation, n, offset) end local criterion = create_criterion(model) - local eval_metric = w2nn.PSNRCriterion():cuda() + local eval_metric = nn.MSECriterion():cuda() local x = torch.load(settings.images) local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x)) local adam_config = { @@ -212,7 +212,7 @@ local function train() elseif settings.color == "rgb" then ch = 3 end - local best_score = 0.0 + local best_score = 1000.0 print("# make validation-set") local valid_xy = make_validation_set(valid_x, pairwise_func, settings.validation_crops, @@ -227,7 +227,6 @@ local function train() ch, settings.crop_size, settings.crop_size) local y = torch.Tensor(settings.patches * #train_x, ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero() - for epoch = 1, settings.epoch do model:training() print("# " .. epoch) @@ -238,12 +237,12 @@ local function train() model:evaluate() print("# validation") local score = validate(model, eval_metric, valid_xy, adam_config.xBatchSize) - table.insert(hist_train, train_score.PSNR) + table.insert(hist_train, train_score.MSE) table.insert(hist_valid, score) if settings.plot then plot(hist_train, hist_valid) end - if score > best_score then + if score < best_score then local test_image = image_loader.load_float(settings.test) -- reload lrd_count = 0 best_score = score @@ -281,7 +280,7 @@ local function train() lrd_count = 0 end end - print("current: " .. score .. ", best: " .. best_score) + print("PSNR: " .. 10 * math.log10(1 / score) .. ", MSE: " .. score .. ", Best MSE: " .. best_score) collectgarbage() end end