From cdafbf00ae1f95f81c9db647ef23e08eec0813f5 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Thu, 23 Feb 2017 08:48:00 +0900 Subject: [PATCH] Add BCE(binary cross entropy) loss for segmentation Sigmoid() output is required. --- lib/settings.lua | 2 +- train.lua | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/settings.lua b/lib/settings.lua index 35f271f..fb0b890 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -75,7 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') cmd:option("-gpu", 1, 'Device ID') -cmd:option("-loss", "huber", 'loss function (huber|l1|mse)') +cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)') cmd:option("-update_criterion", "mse", 'mse|loss') local function to_bool(settings, name) diff --git a/train.lua b/train.lua index 6379aeb..0286aee 100644 --- a/train.lua +++ b/train.lua @@ -322,6 +322,10 @@ local function create_criterion(model) return w2nn.L1Criterion():cuda() elseif settings.loss == "mse" then return w2nn.ClippedMSECriterion(0, 1.0):cuda() + elseif settings.loss == "bce" then + local bce = nn.BCECriterion() + bce.sizeAverage = true + return bce:cuda() else error("unsupported loss .." .. settings.loss) end