From 0eccbc65555ababf1529a70032a0cf009f7841c4 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Mon, 9 Jan 2017 13:00:00 +0900 Subject: [PATCH] Add support for MSE loss --- lib/settings.lua | 2 +- train.lua | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/settings.lua b/lib/settings.lua index ee3dcb3..4507711 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -75,7 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') cmd:option("-gpu", 1, 'Device ID') -cmd:option("-loss", "huber", 'loss function (huber|l1)') +cmd:option("-loss", "huber", 'loss function (huber|l1|mse)') local function to_bool(settings, name) if settings[name] == 1 then diff --git a/train.lua b/train.lua index fe20f89..77536f5 100644 --- a/train.lua +++ b/train.lua @@ -317,6 +317,8 @@ local function create_criterion(model) end elseif settings.loss == "l1" then return w2nn.L1Criterion():cuda() + elseif settings.loss == "mse" then + return w2nn.ClippedMSECriterion(0, 1.0):cuda() else error("unsupported loss .." .. settings.loss) end