1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00

Add aux_lbp; Fix oracle

This commit is contained in:
nagadomi 2018-10-16 13:02:09 +00:00
parent 8065ec9bb2
commit 89c8f5db8e
2 changed files with 25 additions and 6 deletions

View file

@ -46,12 +46,25 @@ local function minibatch_adam(model, criterion, eval_metric,
local f = criterion:forward(output, targets)
local se = 0
if config.xInstanceLoss then
for i = 1, batch_size do
local el = eval_metric:forward(output[i], targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
end
se = (se / batch_size)
if type(output) then
local tbl = {}
for i = 1, batch_size do
for j = 1, #output do
tbl[j] = output[j][i]
end
local el = eval_metric:forward(tbl, targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
end
se = (se / batch_size)
else
for i = 1, batch_size do
local el = eval_metric:forward(output[i], targets[i])
se = se + el
instance_loss[shuffle[t + i - 1]] = el
end
se = (se / batch_size)
end
else
se = eval_metric:forward(output, targets)
end

View file

@ -394,6 +394,12 @@ local function create_criterion(model)
else
return w2nn.RandomBinaryCriterion(1, 512):cuda()
end
elseif settings.loss == "aux_lbp" then
if reconstruct.is_rgb(model) then
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 512}):cuda()
else
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 512}):cuda()
end
else
error("unsupported loss .." .. settings.loss)
end