From 89c8f5db8e3672dd95edc2b98bac8c8c2794f58f Mon Sep 17 00:00:00 2001 From: nagadomi Date: Tue, 16 Oct 2018 13:02:09 +0000 Subject: [PATCH] Add aux_lbp; Fix oracle --- lib/minibatch_adam.lua | 25 +++++++++++++++++++------ train.lua | 6 ++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index a6f7abb..2f8a409 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -46,12 +46,25 @@ local function minibatch_adam(model, criterion, eval_metric, local f = criterion:forward(output, targets) local se = 0 if config.xInstanceLoss then - for i = 1, batch_size do - local el = eval_metric:forward(output[i], targets[i]) - se = se + el - instance_loss[shuffle[t + i - 1]] = el - end - se = (se / batch_size) + if type(output) then + local tbl = {} + for i = 1, batch_size do + for j = 1, #output do + tbl[j] = output[j][i] + end + local el = eval_metric:forward(tbl, targets[i]) + se = se + el + instance_loss[shuffle[t + i - 1]] = el + end + se = (se / batch_size) + else + for i = 1, batch_size do + local el = eval_metric:forward(output[i], targets[i]) + se = se + el + instance_loss[shuffle[t + i - 1]] = el + end + se = (se / batch_size) + end else se = eval_metric:forward(output, targets) end diff --git a/train.lua b/train.lua index 641b06c..8b051ec 100644 --- a/train.lua +++ b/train.lua @@ -394,6 +394,12 @@ local function create_criterion(model) else return w2nn.RandomBinaryCriterion(1, 512):cuda() end + elseif settings.loss == "aux_lbp" then + if reconstruct.is_rgb(model) then + return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {3, 512}):cuda() + else + return w2nn.AuxiliaryLossCriterion(w2nn.RandomBinaryCriterion, {1, 512}):cuda() + end else error("unsupported loss .." .. settings.loss) end