1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00

Add support for criterion arguments in AuxiliaryLossCriterion

This commit is contained in:
nagadomi 2018-10-14 01:15:51 +09:00
parent 0883b043ce
commit aef969d64b

View file

@@ -1,9 +1,10 @@
require 'nn'
local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
function AuxiliaryLossCriterion:__init(base_criterion)
function AuxiliaryLossCriterion:__init(base_criterion, args)
parent.__init(self)
self.base_criterion = base_criterion
self.args = args
self.criterions = {}
self.gradInput = {}
self.sizeAverage = false
@@ -14,7 +15,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
-- model:training()
for i = 1, #input do
if self.criterions[i] == nil then
self.criterions[i] = self.base_criterion()
if self.args ~= nil then
self.criterions[i] = self.base_criterion(table.unpack(self.args))
else
self.criterions[i] = self.base_criterion()
end
self.criterions[i].sizeAverage = self.sizeAverage
if input[i]:type() == "torch.CudaTensor" then
self.criterions[i]:cuda()
@@ -27,7 +32,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
else
-- model:evaluate()
if self.criterions[1] == nil then
self.criterions[1] = self.base_criterion()
if self.args ~= nil then
self.criterions[1] = self.base_criterion(table.unpack(self.args))
else
self.criterions[1] = self.base_criterion()
end
self.criterions[1].sizeAverage = self.sizeAverage
if input:type() == "torch.CudaTensor" then
self.criterions[1]:cuda()