Add aux_lbp; Fix oracle
This commit is contained in:
parent
8065ec9bb2
commit
89c8f5db8e
|
@ -46,12 +46,25 @@ local function minibatch_adam(model, criterion, eval_metric,
|
|||
local f = criterion:forward(output, targets)
|
||||
local se = 0
|
||||
if config.xInstanceLoss then
|
||||
if type(output) then
|
||||
local tbl = {}
|
||||
for i = 1, batch_size do
|
||||
for j = 1, #output do
|
||||
tbl[j] = output[j][i]
|
||||
end
|
||||
local el = eval_metric:forward(tbl, targets[i])
|
||||
se = se + el
|
||||
instance_loss[shuffle[t + i - 1]] = el
|
||||
end
|
||||
se = (se / batch_size)
|
||||
else
|
||||
for i = 1, batch_size do
|
||||
local el = eval_metric:forward(output[i], targets[i])
|
||||
se = se + el
|
||||
instance_loss[shuffle[t + i - 1]] = el
|
||||
end
|
||||
se = (se / batch_size)
|
||||
end
|
||||
else
|
||||
se = eval_metric:forward(output, targets)
|
||||
end
|
||||
|
|
|
@ -394,6 +394,12 @@ local function create_criterion(model)
|
|||
else
|
||||
return w2nn.RandomBinaryCriterion(1, 512):cuda()
|
||||
end
|
||||
elseif settings.loss == "aux_lbp" then
|
||||
if reconstruct.is_rgb(model) then
|
||||
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 512}):cuda()
|
||||
else
|
||||
return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 512}):cuda()
|
||||
end
|
||||
else
|
||||
error("unsupported loss .." .. settings.loss)
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue