Add support for criterion arguments in AuxiliaryLossCriterion
This commit is contained in:
parent
0883b043ce
commit
aef969d64b
|
@ -1,9 +1,10 @@
|
||||||
require 'nn'
|
require 'nn'
|
||||||
local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
|
local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
|
||||||
|
|
||||||
function AuxiliaryLossCriterion:__init(base_criterion)
|
function AuxiliaryLossCriterion:__init(base_criterion, args)
|
||||||
parent.__init(self)
|
parent.__init(self)
|
||||||
self.base_criterion = base_criterion
|
self.base_criterion = base_criterion
|
||||||
|
self.args = args
|
||||||
self.criterions = {}
|
self.criterions = {}
|
||||||
self.gradInput = {}
|
self.gradInput = {}
|
||||||
self.sizeAverage = false
|
self.sizeAverage = false
|
||||||
|
@ -14,7 +15,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
||||||
-- model:training()
|
-- model:training()
|
||||||
for i = 1, #input do
|
for i = 1, #input do
|
||||||
if self.criterions[i] == nil then
|
if self.criterions[i] == nil then
|
||||||
self.criterions[i] = self.base_criterion()
|
if self.args ~= nil then
|
||||||
|
self.criterions[i] = self.base_criterion(table.unpack(self.args))
|
||||||
|
else
|
||||||
|
self.criterions[i] = self.base_criterion()
|
||||||
|
end
|
||||||
self.criterions[i].sizeAverage = self.sizeAverage
|
self.criterions[i].sizeAverage = self.sizeAverage
|
||||||
if input[i]:type() == "torch.CudaTensor" then
|
if input[i]:type() == "torch.CudaTensor" then
|
||||||
self.criterions[i]:cuda()
|
self.criterions[i]:cuda()
|
||||||
|
@ -27,7 +32,11 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
||||||
else
|
else
|
||||||
-- model:evaluate()
|
-- model:evaluate()
|
||||||
if self.criterions[1] == nil then
|
if self.criterions[1] == nil then
|
||||||
self.criterions[1] = self.base_criterion()
|
if self.args ~= nil then
|
||||||
|
self.criterions[1] = self.base_criterion(table.unpack(self.args))
|
||||||
|
else
|
||||||
|
self.criterions[1] = self.base_criterion()
|
||||||
|
end
|
||||||
self.criterions[1].sizeAverage = self.sizeAverage()
|
self.criterions[1].sizeAverage = self.sizeAverage()
|
||||||
if input:type() == "torch.CudaTensor" then
|
if input:type() == "torch.CudaTensor" then
|
||||||
self.criterions[1]:cuda()
|
self.criterions[1]:cuda()
|
||||||
|
|
Loading…
Reference in a new issue