Add support for criterion arguments in AuxiliaryLossCriterion
This commit is contained in:
parent
0883b043ce
commit
aef969d64b
|
@ -1,9 +1,10 @@
|
|||
require 'nn'
|
||||
local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
|
||||
|
||||
function AuxiliaryLossCriterion:__init(base_criterion)
|
||||
function AuxiliaryLossCriterion:__init(base_criterion, args)
|
||||
parent.__init(self)
|
||||
self.base_criterion = base_criterion
|
||||
self.args = args
|
||||
self.criterions = {}
|
||||
self.gradInput = {}
|
||||
self.sizeAverage = false
|
||||
|
@ -14,7 +15,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
|||
-- model:training()
|
||||
for i = 1, #input do
|
||||
if self.criterions[i] == nil then
|
||||
self.criterions[i] = self.base_criterion()
|
||||
if self.args ~= nil then
|
||||
self.criterions[i] = self.base_criterion(table.unpack(self.args))
|
||||
else
|
||||
self.criterions[i] = self.base_criterion()
|
||||
end
|
||||
self.criterions[i].sizeAverage = self.sizeAverage
|
||||
if input[i]:type() == "torch.CudaTensor" then
|
||||
self.criterions[i]:cuda()
|
||||
|
@ -27,7 +32,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
|||
else
|
||||
-- model:evaluate()
|
||||
if self.criterions[1] == nil then
|
||||
self.criterions[1] = self.base_criterion()
|
||||
if self.args ~= nil then
|
||||
self.criterions[1] = self.base_criterion(table.unpack(self.args))
|
||||
else
|
||||
self.criterions[1] = self.base_criterion()
|
||||
end
|
||||
self.criterions[1].sizeAverage = self.sizeAverage()
|
||||
if input:type() == "torch.CudaTensor" then
|
||||
self.criterions[1]:cuda()
|
||||
|
|
Loading…
Reference in a new issue