
Add support for AuxiliaryLoss

nagadomi 2018-10-03 18:35:38 +09:00
parent cf638aea48
commit 56536ac133
4 changed files with 104 additions and 1 deletion

lib/AuxiliaryLossCriterion.lua

@@ -0,0 +1,51 @@
require 'nn'
-- Criterion that applies a base criterion to each output of a
-- multi-output (auxiliary loss) model and averages the results.
local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion', 'nn.Criterion')
function AuxiliaryLossCriterion:__init(base_criterion)
   parent.__init(self)
   self.base_criterion = base_criterion
   self.criterions = {}
   self.gradInput = {}
   self.sizeAverage = false
end
function AuxiliaryLossCriterion:updateOutput(input, target)
   local sum_output = 0
   if type(input) == "table" then
      -- model:training(): input is a table of outputs
      for i = 1, #input do
         if self.criterions[i] == nil then
            -- lazily build one sub-criterion per output
            self.criterions[i] = self.base_criterion()
            self.criterions[i].sizeAverage = self.sizeAverage
            if input[i]:type() == "torch.CudaTensor" then
               self.criterions[i]:cuda()
            end
         end
         local output = self.criterions[i]:updateOutput(input[i], target)
         sum_output = sum_output + output
      end
      self.output = sum_output / #input
   else
      -- model:evaluate(): input is a single tensor
      if self.criterions[1] == nil then
         self.criterions[1] = self.base_criterion()
         self.criterions[1].sizeAverage = self.sizeAverage
         if input:type() == "torch.CudaTensor" then
            self.criterions[1]:cuda()
         end
      end
      self.output = self.criterions[1]:updateOutput(input, target)
   end
   return self.output
end
function AuxiliaryLossCriterion:updateGradInput(input, target)
   for i = 1, #input do
      local gradInput = self.criterions[i]:updateGradInput(input[i], target)
      self.gradInput[i] = self.gradInput[i] or gradInput.new()
      self.gradInput[i]:resizeAs(gradInput):copy(gradInput)
   end
   -- drop stale entries if the number of outputs shrank
   for i = #input + 1, #self.gradInput do
      self.gradInput[i] = nil
   end
   return self.gradInput
end
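For context, here is a minimal sketch of how this criterion behaves end to end. Everything in it (the tensor shapes, the nn.MSECriterion base, the require lines) is illustrative and not part of the commit; it only assumes that waifu2x's package path is set up so that `require 'w2nn'` loads the new class.

```lua
require 'nn'
require 'w2nn' -- assumption: waifu2x's package path makes this resolve

-- hypothetical usage with MSE as the base criterion
local criterion = w2nn.AuxiliaryLossCriterion(nn.MSECriterion)
local target = torch.rand(4, 3)

-- training: the model returns a table; losses are averaged over outputs
local outputs = { torch.rand(4, 3), torch.rand(4, 3) }
local loss = criterion:updateOutput(outputs, target)     -- (loss1 + loss2) / 2
local grads = criterion:updateGradInput(outputs, target) -- table, one grad per output

-- evaluation: a single tensor falls through to the first sub-criterion
local eval_loss = criterion:updateOutput(torch.rand(4, 3), target)
```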

lib/AuxiliaryLossTable.lua

@@ -0,0 +1,40 @@
require 'nn'
-- Module that passes through all table inputs during training (so each
-- branch can receive an auxiliary loss) and returns only the i-th
-- output during evaluation.
local AuxiliaryLossTable, parent = torch.class('w2nn.AuxiliaryLossTable', 'nn.Module')
function AuxiliaryLossTable:__init(i)
   parent.__init(self)
   self.i = i or 1 -- index of the output to keep at evaluation time
   self.gradInput = {}
   self.output_table = {}
   self.output_tensor = torch.Tensor()
end
function AuxiliaryLossTable:updateOutput(input)
   if self.train then
      -- training: forward every output
      for i = 1, #input do
         self.output_table[i] = self.output_table[i] or input[1].new()
         self.output_table[i]:resizeAs(input[i]):copy(input[i])
      end
      for i = #input + 1, #self.output_table do
         self.output_table[i] = nil
      end
      self.output = self.output_table
   else
      -- evaluation: forward only the selected output
      self.output_tensor:resizeAs(input[self.i]):copy(input[self.i])
      self.output = self.output_tensor
   end
   return self.output
end
function AuxiliaryLossTable:updateGradInput(input, gradOutput)
   for i = 1, #input do
      self.gradInput[i] = self.gradInput[i] or input[1].new()
      self.gradInput[i]:resizeAs(input[i]):copy(gradOutput[i])
   end
   for i = #input + 1, #self.gradInput do
      self.gradInput[i] = nil
   end
   return self.gradInput
end
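A rough sketch of where AuxiliaryLossTable could sit in a network. The two-headed model below is hypothetical (this commit contains no model changes); it only illustrates the train/evaluate switch implemented above.

```lua
require 'nn'
require 'w2nn' -- assumption: loads w2nn.AuxiliaryLossTable

-- hypothetical two-headed model: a main head plus one auxiliary head
local model = nn.Sequential()
   :add(nn.Linear(16, 16))
   :add(nn.ConcatTable()
      :add(nn.Linear(16, 8))   -- main output (index 1)
      :add(nn.Linear(16, 8)))  -- auxiliary output
   :add(w2nn.AuxiliaryLossTable(1)) -- keep output 1 at evaluation time

model:training()
local out = model:forward(torch.rand(4, 16))  -- table of 2 tensors
model:evaluate()
local pred = model:forward(torch.rand(4, 16)) -- single tensor (head 1)
```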

lib/w2nn.lua

@@ -77,5 +77,8 @@ else
   require 'ShakeShakeTable'
   require 'PrintTable'
   require 'Print'
   require 'AuxiliaryLossTable'
   require 'AuxiliaryLossCriterion'
   return w2nn
end

train.lua

@@ -365,6 +365,10 @@ local function create_criterion(model)
      local bce = nn.BCECriterion()
      bce.sizeAverage = true
      return bce:cuda()
   elseif settings.loss == "aux_bce" then
      -- BCE applied to each auxiliary output, averaged by the wrapper
      local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
      aux.sizeAverage = true
      return aux:cuda()
   else
      error("unsupported loss: " .. settings.loss)
   end
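One detail worth noting about the block above: `aux.sizeAverage = true` takes effect lazily, because AuxiliaryLossCriterion copies the flag onto each per-output BCECriterion only when that sub-criterion is first constructed. A small illustration (all values hypothetical):

```lua
require 'nn'
require 'w2nn' -- assumption: as in the sketches above

local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
aux.sizeAverage = true -- effective only because no sub-criterion exists yet

-- clamp into (0, 1) so BCE's log() stays finite
local y = { torch.rand(2, 4):clamp(0.01, 0.99), torch.rand(2, 4):clamp(0.01, 0.99) }
local t = torch.rand(2, 4):clamp(0.01, 0.99)
aux:updateOutput(y, t) -- sub-criterions are built here with sizeAverage = true
```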
@@ -500,7 +504,12 @@ local function train()
   transform_pool_init(reconstruct.has_resize(model), offset)
   local criterion = create_criterion(model)
-  local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
+  local eval_metric = nil
+  if settings.loss:find("aux_") ~= nil then
+     eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
+  else
+     eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
+  end
   local adam_config = {
      xLearningRate = settings.learning_rate,
      xBatchSize = settings.batch_size,
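Wrapping the evaluation metric in AuxiliaryLossCriterion keeps validation working in both regimes: given a table (a training-mode forward) it averages ClippedMSE across all outputs, and given a plain tensor (an evaluate-mode forward) it falls through to the single-criterion path shown in AuxiliaryLossCriterion.lua above.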