Add support for AuxiliaryLoss
This commit is contained in:
parent
cf638aea48
commit
56536ac133
51
lib/AuxiliaryLossCriterion.lua
Normal file
51
lib/AuxiliaryLossCriterion.lua
Normal file
|
@@ -0,0 +1,51 @@
|
||||||
|
require 'nn'
|
||||||
|
local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
|
||||||
|
|
||||||
|
--- Criterion that wraps a base criterion and averages it over multiple
-- auxiliary network outputs (deep supervision).
-- @param base_criterion constructor of the wrapped criterion, e.g. nn.BCECriterion
function AuxiliaryLossCriterion:__init(base_criterion)
   parent.__init(self)
   -- per-output criterion instances, created lazily in updateOutput
   self.criterions = {}
   self.gradInput = {}
   -- constructor used to build one criterion per auxiliary output
   self.base_criterion = base_criterion
   -- propagated to each lazily created criterion
   self.sizeAverage = false
end
|
||||||
|
--- Forward pass.
-- Training mode: `input` is a table of auxiliary outputs; returns the mean
-- of the base criterion's loss over all of them against the same `target`.
-- Evaluation mode: `input` is a single tensor; returns the plain loss.
-- Criterions are created lazily and moved to CUDA when the input is a
-- torch.CudaTensor.
function AuxiliaryLossCriterion:updateOutput(input, target)
   if type(input) == "table" then
      -- training path: average the loss across all auxiliary outputs
      local sum_output = 0
      for i = 1, #input do
         if self.criterions[i] == nil then
            self.criterions[i] = self.base_criterion()
            self.criterions[i].sizeAverage = self.sizeAverage
            if input[i]:type() == "torch.CudaTensor" then
               self.criterions[i]:cuda()
            end
         end
         sum_output = sum_output + self.criterions[i]:updateOutput(input[i], target)
      end
      self.output = sum_output / #input
   else
      -- evaluation path: single tensor output
      if self.criterions[1] == nil then
         self.criterions[1] = self.base_criterion()
         -- BUG FIX: the original wrote `self.sizeAverage()`, calling a
         -- boolean field as a function -> runtime error on first eval forward.
         self.criterions[1].sizeAverage = self.sizeAverage
         if input:type() == "torch.CudaTensor" then
            self.criterions[1]:cuda()
         end
      end
      self.output = self.criterions[1]:updateOutput(input, target)
   end
   return self.output
end
|
||||||
|
|
||||||
|
--- Backward pass: delegate to each per-output criterion and collect the
-- gradients into self.gradInput (a table parallel to `input`).
-- Only meaningful in training mode, where `input` is a table.
function AuxiliaryLossCriterion:updateGradInput(input, target)
   local n = #input
   for i = 1, n do
      local g = self.criterions[i]:updateGradInput(input[i], target)
      if self.gradInput[i] == nil then
         self.gradInput[i] = g.new()
      end
      self.gradInput[i]:resizeAs(g):copy(g)
   end
   -- drop stale entries left over from a previous call with more outputs
   for i = n + 1, #self.gradInput do
      self.gradInput[i] = nil
   end
   return self.gradInput
end
|
40
lib/AuxiliaryLossTable.lua
Normal file
40
lib/AuxiliaryLossTable.lua
Normal file
|
@@ -0,0 +1,40 @@
|
||||||
|
require 'nn'
|
||||||
|
local AuxiliaryLossTable, parent = torch.class('w2nn.AuxiliaryLossTable', 'nn.Module')
|
||||||
|
|
||||||
|
--- Module that passes through all branch outputs as a table during
-- training, but returns only one selected branch during evaluation.
-- @param i index of the branch forwarded at evaluation time (default 1)
function AuxiliaryLossTable:__init(i)
   parent.__init(self)
   self.gradInput = {}
   -- reusable buffers for the two output modes
   self.output_tensor = torch.Tensor()
   self.output_table = {}
   self.i = i or 1
end
|
||||||
|
|
||||||
|
--- Forward pass.
-- Training: copies every input tensor into a reusable output table.
-- Evaluation: returns a copy of only the self.i-th branch as a tensor.
function AuxiliaryLossTable:updateOutput(input)
   if self.train then
      for i = 1, #input do
         self.output_table[i] = self.output_table[i] or input[1].new()
         self.output_table[i]:resizeAs(input[i]):copy(input[i])
      end
      -- drop stale entries from a previous call with more branches
      for i = #input + 1, #self.output_table do
         self.output_table[i] = nil
      end
      self.output = self.output_table
   else
      -- BUG FIX: the original resized to and copied input[1], then copied
      -- input[self.i] over it -- a redundant copy, and the resize targeted
      -- the wrong tensor whenever branch shapes differ from branch 1.
      self.output_tensor:resizeAs(input[self.i]):copy(input[self.i])
      self.output = self.output_tensor
   end
   return self.output
end
|
||||||
|
|
||||||
|
--- Backward pass: copy each branch's incoming gradient into a reusable
-- gradInput table shaped like `input`.
function AuxiliaryLossTable:updateGradInput(input, gradOutput)
   local n = #input
   for i = 1, n do
      if self.gradInput[i] == nil then
         self.gradInput[i] = input[1].new()
      end
      self.gradInput[i]:resizeAs(input[i]):copy(gradOutput[i])
   end
   -- drop stale entries from a previous call with more branches
   for i = n + 1, #self.gradInput do
      self.gradInput[i] = nil
   end
   return self.gradInput
end
|
|
@@ -77,5 +77,8 @@ else
|
||||||
require 'ShakeShakeTable'
|
require 'ShakeShakeTable'
|
||||||
require 'PrintTable'
|
require 'PrintTable'
|
||||||
require 'Print'
|
require 'Print'
|
||||||
|
require 'AuxiliaryLossTable'
|
||||||
|
require 'AuxiliaryLossCriterion'
|
||||||
return w2nn
|
return w2nn
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
11
train.lua
11
train.lua
|
@@ -365,6 +365,10 @@ local function create_criterion(model)
|
||||||
local bce = nn.BCECriterion()
|
local bce = nn.BCECriterion()
|
||||||
bce.sizeAverage = true
|
bce.sizeAverage = true
|
||||||
return bce:cuda()
|
return bce:cuda()
|
||||||
|
elseif settings.loss == "aux_bce" then
|
||||||
|
local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
|
||||||
|
aux.sizeAverage = true
|
||||||
|
return aux:cuda()
|
||||||
else
|
else
|
||||||
error("unsupported loss .." .. settings.loss)
|
error("unsupported loss .." .. settings.loss)
|
||||||
end
|
end
|
||||||
|
@@ -500,7 +504,12 @@ local function train()
|
||||||
transform_pool_init(reconstruct.has_resize(model), offset)
|
transform_pool_init(reconstruct.has_resize(model), offset)
|
||||||
|
|
||||||
local criterion = create_criterion(model)
|
local criterion = create_criterion(model)
|
||||||
local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
|
local eval_metric = nil
|
||||||
|
if settings.loss:find("aux_") ~= nil then
|
||||||
|
eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
|
||||||
|
else
|
||||||
|
eval_metric = w2nn.ClippedMSECriterion()
|
||||||
|
end
|
||||||
local adam_config = {
|
local adam_config = {
|
||||||
xLearningRate = settings.learning_rate,
|
xLearningRate = settings.learning_rate,
|
||||||
xBatchSize = settings.batch_size,
|
xBatchSize = settings.batch_size,
|
||||||
|
|
Loading…
Reference in a new issue