diff --git a/lib/AuxiliaryLossCriterion.lua b/lib/AuxiliaryLossCriterion.lua
index 1f531ef..63170fd 100644
--- a/lib/AuxiliaryLossCriterion.lua
+++ b/lib/AuxiliaryLossCriterion.lua
@@ -1,9 +1,10 @@
 require 'nn'
 local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
-function AuxiliaryLossCriterion:__init(base_criterion)
+function AuxiliaryLossCriterion:__init(base_criterion, args)
    parent.__init(self)
    self.base_criterion = base_criterion
+   self.args = args
    self.criterions = {}
    self.gradInput = {}
    self.sizeAverage = false
@@ -14,7 +15,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
       -- model:training()
       for i = 1, #input do
	 if self.criterions[i] == nil then
-	    self.criterions[i] = self.base_criterion()
+	    if self.args ~= nil then
+	       self.criterions[i] = self.base_criterion(table.unpack(self.args))
+	    else
+	       self.criterions[i] = self.base_criterion()
+	    end
	    self.criterions[i].sizeAverage = self.sizeAverage
	    if input[i]:type() == "torch.CudaTensor" then
	       self.criterions[i]:cuda()
@@ -27,7 +32,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
    else
       -- model:evaluate()
       if self.criterions[1] == nil then
-	 self.criterions[1] = self.base_criterion()
-	 self.criterions[1].sizeAverage = self.sizeAverage()
+	 if self.args ~= nil then
+	    self.criterions[1] = self.base_criterion(table.unpack(self.args))
+	 else
+	    self.criterions[1] = self.base_criterion()
+	 end
+	 self.criterions[1].sizeAverage = self.sizeAverage
	 if input:type() == "torch.CudaTensor" then
	    self.criterions[1]:cuda()