Add BCE(binary cross entropy) loss for segmentation
Sigmoid() output is required.
This commit is contained in:
parent
dac1b89750
commit
cdafbf00ae
|
@ -75,7 +75,7 @@ cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate *
|
||||||
cmd:option("-resume", "", 'resume model file')
|
cmd:option("-resume", "", 'resume model file')
|
||||||
cmd:option("-name", "user", 'model name for user method')
|
cmd:option("-name", "user", 'model name for user method')
|
||||||
cmd:option("-gpu", 1, 'Device ID')
|
cmd:option("-gpu", 1, 'Device ID')
|
||||||
cmd:option("-loss", "huber", 'loss function (huber|l1|mse)')
|
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
|
||||||
cmd:option("-update_criterion", "mse", 'mse|loss')
|
cmd:option("-update_criterion", "mse", 'mse|loss')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
|
|
|
@ -322,6 +322,10 @@ local function create_criterion(model)
|
||||||
return w2nn.L1Criterion():cuda()
|
return w2nn.L1Criterion():cuda()
|
||||||
elseif settings.loss == "mse" then
|
elseif settings.loss == "mse" then
|
||||||
return w2nn.ClippedMSECriterion(0, 1.0):cuda()
|
return w2nn.ClippedMSECriterion(0, 1.0):cuda()
|
||||||
|
elseif settings.loss == "bce" then
|
||||||
|
local bce = nn.BCECriterion()
|
||||||
|
bce.sizeAverage = true
|
||||||
|
return bce:cuda()
|
||||||
else
|
else
|
||||||
error("unsupported loss .." .. settings.loss)
|
error("unsupported loss .." .. settings.loss)
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue