diff --git a/lib/settings.lua b/lib/settings.lua index 35f271f..fb0b890 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -75,7 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') cmd:option("-gpu", 1, 'Device ID') -cmd:option("-loss", "huber", 'loss function (huber|l1|mse)') +cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)') cmd:option("-update_criterion", "mse", 'mse|loss') local function to_bool(settings, name) diff --git a/train.lua b/train.lua index 6379aeb..0286aee 100644 --- a/train.lua +++ b/train.lua @@ -322,6 +322,10 @@ local function create_criterion(model) return w2nn.L1Criterion():cuda() elseif settings.loss == "mse" then return w2nn.ClippedMSECriterion(0, 1.0):cuda() + elseif settings.loss == "bce" then + local bce = nn.BCECriterion() + bce.sizeAverage = true + return bce:cuda() else error("unsupported loss .." .. settings.loss) end