diff --git a/lib/AuxiliaryLossCriterion.lua b/lib/AuxiliaryLossCriterion.lua
index 81c9682..65b5fb9 100644
--- a/lib/AuxiliaryLossCriterion.lua
+++ b/lib/AuxiliaryLossCriterion.lua
@@ -5,14 +5,18 @@ function AuxiliaryLossCriterion:__init(base_criterion, args)
    parent.__init(self)
    self.base_criterion = base_criterion
    self.args = args
-   self.criterions = {}
    self.gradInput = {}
    self.sizeAverage = false
+   self.criterions = {}
+   if self.base_criterion.has_instance_loss then
+      self.instance_loss = {}
+   end
 end
 function AuxiliaryLossCriterion:updateOutput(input, target)
    local sum_output = 0
    if type(input) == "table" then
       -- model:training()
+      self.output = 0
       for i = 1, #input do
          if self.criterions[i] == nil then
             if self.args ~= nil then
@@ -25,10 +29,22 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
                self.criterions[i]:cuda()
             end
          end
-         local output = self.criterions[i]:updateOutput(input[i], target)
-         sum_output = sum_output + output
+         self.output = self.output + self.criterions[i]:updateOutput(input[i], target) / #input
+
+         if self.instance_loss then
+            local batch_size = #self.criterions[i].instance_loss
+            local scale = 1.0 / #input
+            if i == 1 then
+               for j = 1, batch_size do
+                  self.instance_loss[j] = self.criterions[i].instance_loss[j] * scale
+               end
+            else
+               for j = 1, batch_size do
+                  self.instance_loss[j] = self.instance_loss[j] + self.criterions[i].instance_loss[j] * scale
+               end
+            end
+         end
       end
-      self.output = sum_output / #input
    else
       -- model:evaluate()
       if self.criterions[1] == nil then
@@ -43,6 +59,12 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
          end
       end
       self.output = self.criterions[1]:updateOutput(input, target)
+      if self.instance_loss then
+         local batch_size = #self.criterions[1].instance_loss
+         for j = 1, batch_size do
+            self.instance_loss[j] = self.criterions[1].instance_loss[j]
+         end
+      end
    end
    return self.output
 end
diff --git a/lib/ClippedMSECriterion.lua b/lib/ClippedMSECriterion.lua
index a50b371..4a92db8 100644
--- a/lib/ClippedMSECriterion.lua
+++ b/lib/ClippedMSECriterion.lua
@@ -1,19 +1,28 @@
 local ClippedMSECriterion, parent = torch.class('w2nn.ClippedMSECriterion','nn.Criterion')
+ClippedMSECriterion.has_instance_loss = true
 
 function ClippedMSECriterion:__init(min, max)
    parent.__init(self)
    self.min = min or 0
    self.max = max or 1
    self.diff = torch.Tensor()
    self.diff_pow2 = torch.Tensor()
+   self.instance_loss = {}
 end
 function ClippedMSECriterion:updateOutput(input, target)
    self.diff:resizeAs(input):copy(input)
    self.diff:clamp(self.min, self.max)
    self.diff:add(-1, target)
    self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
-   self.output = self.diff_pow2:sum() / input:nElement()
-   return self.output
+   self.instance_loss = {}
+   self.output = 0
+   local scale = 1.0 / input:size(1)
+   for i = 1, input:size(1) do
+      local instance_loss = self.diff_pow2[i]:sum() / self.diff_pow2[i]:nElement()
+      self.instance_loss[i] = instance_loss
+      self.output = self.output + instance_loss * scale
+   end
+   return self.output
 end
 function ClippedMSECriterion:updateGradInput(input, target)
    local norm = 1.0 / input:nElement()
diff --git a/lib/LBPCriterion.lua b/lib/LBPCriterion.lua
index 9e7ade7..255af86 100644
--- a/lib/LBPCriterion.lua
+++ b/lib/LBPCriterion.lua
@@ -55,9 +55,7 @@ function LBPCriterion:updateOutput(input, target)
 
    -- huber loss
    self.diff:resizeAs(lb1):copy(lb1)
-   for i = 1, lb1:size(1) do
-      self.diff[i]:add(-1, lb2[i])
-   end
+   self.diff:add(-1, lb2)
    self.diff_abs:resizeAs(self.diff):copy(self.diff):abs()
 
    local square_targets = self.diff[torch.lt(self.diff_abs, self.gamma)]
diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua
index 2f8a409..564eb69 100644
--- a/lib/minibatch_adam.lua
+++ b/lib/minibatch_adam.lua
@@ -44,29 +44,12 @@ local function minibatch_adam(model, criterion, eval_metric,
          gradParameters:zero()
          local output = model:forward(inputs)
          local f = criterion:forward(output, targets)
-         local se = 0
+         local se = eval_metric:forward(output, targets)
          if config.xInstanceLoss then
-            if type(output) then
-               local tbl = {}
-               for i = 1, batch_size do
-                  for j = 1, #output do
-                     tbl[j] = output[j][i]
-                  end
-                  local el = eval_metric:forward(tbl, targets[i])
-                  se = se + el
-                  instance_loss[shuffle[t + i - 1]] = el
-               end
-               se = (se / batch_size)
-            else
-               for i = 1, batch_size do
-                  local el = eval_metric:forward(output[i], targets[i])
-                  se = se + el
-                  instance_loss[shuffle[t + i - 1]] = el
-               end
-               se = (se / batch_size)
-            end
-         else
-            se = eval_metric:forward(output, targets)
+            assert(eval_metric.instance_loss, "eval metric does not support instance_loss")
+            for i = 1, #eval_metric.instance_loss do
+               instance_loss[shuffle[t + i - 1]] = eval_metric.instance_loss[i]
+            end
          end
          sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
          sum_eval = sum_eval + se
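
For reference, a minimal sketch of how the per-instance losses exposed by this patch could be read after a forward pass. It is illustrative only, not part of the patch: it assumes the waifu2x w2nn package is on the Lua package path, and the `output`/`targets` tensors are random stand-ins for real model output.

    require 'torch'
    require 'nn'
    require 'w2nn' -- assumed: waifu2x lib/ directory on the Lua package path

    -- random stand-ins for a model output and its targets (batch of 4)
    local output = torch.Tensor(4, 3, 8, 8):uniform(0, 1)
    local targets = torch.Tensor(4, 3, 8, 8):uniform(0, 1)

    -- ClippedMSECriterion sets has_instance_loss, so the wrapper tracks per-sample losses
    local criterion = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
    local loss = criterion:forward(output, targets) -- mean clipped MSE over the batch
    print("batch loss:", loss)

    if criterion.instance_loss then
       for i = 1, #criterion.instance_loss do
          -- per-sample loss, usable for xInstanceLoss-style resampling in minibatch_adam
          print(("sample %d: %f"):format(i, criterion.instance_loss[i]))
       end
    end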