1
0
Fork 0
mirror of synced 2024-06-01 10:39:30 +12:00

Add aux_lbp; Fix oracle

This commit is contained in:
nagadomi 2018-10-16 13:02:09 +00:00
parent 8065ec9bb2
commit 89c8f5db8e
2 changed files with 25 additions and 6 deletions

View file

@ -46,12 +46,25 @@ local function minibatch_adam(model, criterion, eval_metric,
local f = criterion:forward(output, targets) local f = criterion:forward(output, targets)
local se = 0 local se = 0
if config.xInstanceLoss then if config.xInstanceLoss then
for i = 1, batch_size do if type(output) then
local el = eval_metric:forward(output[i], targets[i]) local tbl = {}
se = se + el for i = 1, batch_size do
instance_loss[shuffle[t + i - 1]] = el for j = 1, #output do
end tbl[j] = output[j][i]
se = (se / batch_size) end
local el = eval_metric:forward(tbl, targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
end
se = (se / batch_size)
else
for i = 1, batch_size do
local el = eval_metric:forward(output[i], targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
end
se = (se / batch_size)
end
else else
se = eval_metric:forward(output, targets) se = eval_metric:forward(output, targets)
end end

View file

@ -394,6 +394,12 @@ local function create_criterion(model)
else else
return w2nn.RandomBinaryCriterion(1, 512):cuda() return w2nn.RandomBinaryCriterion(1, 512):cuda()
end end
elseif settings.loss == "aux_lbp" then
if reconstruct.is_rgb(model) then
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 512}):cuda()
else
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 512}):cuda()
end
else else
error("unsupported loss .." .. settings.loss) error("unsupported loss .." .. settings.loss)
end end