41 lines
1.1 KiB
Lua
41 lines
1.1 KiB
Lua
|
require 'nn'
|
||
|
-- Declare the Torch class `w2nn.AuxiliaryLossTable`, derived from nn.Module.
-- `parent` is the superclass table; it is used below to run nn.Module's
-- constructor.
local AuxiliaryLossTable, parent =
   torch.class('w2nn.AuxiliaryLossTable', 'nn.Module')
|
||
|
|
||
|
--- Constructor.
-- @param i optional index of the input that is forwarded when the module is
--   not in training mode (`self.train` falsy); any falsy value means 1.
function AuxiliaryLossTable:__init(i)
   parent.__init(self)
   if i then
      self.i = i
   else
      self.i = 1
   end
   -- Buffers reused across calls: a single tensor for the non-training
   -- output, and per-input tables for the training output and the gradients.
   self.output_tensor = torch.Tensor()
   self.output_table = {}
   self.gradInput = {}
end
|
||
|
|
||
|
--- Forward pass.
-- When `self.train` is truthy, every tensor in `input` is copied into a
-- cached table, which becomes the output. Otherwise only `input[self.i]` is
-- copied into a cached tensor, which becomes the output.
-- @tparam table input sequence of tensors
-- @return self.output (a table when training, a tensor otherwise)
function AuxiliaryLossTable:updateOutput(input)
   if self.train then
      for i = 1, #input do
         -- Allocate each buffer from its own input so it mirrors that
         -- input's tensor type rather than input[1]'s.
         self.output_table[i] = self.output_table[i] or input[i].new()
         self.output_table[i]:resizeAs(input[i]):copy(input[i])
      end
      -- Drop buffers left over from an earlier call that had more inputs.
      for i = #self.output_table, #input + 1, -1 do
         self.output_table[i] = nil
      end
      self.output = self.output_table
   else
      local selected = input[self.i]
      assert(selected, "AuxiliaryLossTable: input has no element at index "
         .. tostring(self.i))
      -- Size the buffer from the selected input itself. Sizing it from
      -- input[1] fails whenever the two tensors differ in element count.
      self.output_tensor:resizeAs(selected):copy(selected)
      self.output = self.output_tensor
   end
   return self.output
end
|
||
|
|
||
|
--- Backward pass.
-- Builds `self.gradInput`, a table with one gradient tensor per input.
-- If `gradOutput` is a table (training-mode output), entry i is copied into
-- gradInput[i]; a missing entry yields a zero gradient.
-- If `gradOutput` is a single tensor (the non-training output, which mirrors
-- only input[self.i]), it is routed to gradInput[self.i] and every other
-- input receives a zero gradient.
-- @tparam table input sequence of tensors given to updateOutput
-- @param gradOutput table of tensors or a single tensor
-- @treturn table self.gradInput
function AuxiliaryLossTable:updateGradInput(input, gradOutput)
   -- A single tensor cannot be indexed per input: t[i] would select a slice.
   local single = torch.isTensor(gradOutput)
   for i = 1, #input do
      local grad_in = self.gradInput[i] or input[i].new()
      self.gradInput[i] = grad_in
      grad_in:resizeAs(input[i])
      local g
      if single then
         if i == self.i then
            g = gradOutput
         end
      else
         g = gradOutput[i]
      end
      if g then
         grad_in:copy(g)
      else
         grad_in:zero()
      end
   end
   -- Drop buffers left over from an earlier call that had more inputs.
   for i = #self.gradInput, #input + 1, -1 do
      self.gradInput[i] = nil
   end
   return self.gradInput
end
|