diff --git a/lib/AuxiliaryLossCriterion.lua b/lib/AuxiliaryLossCriterion.lua
new file mode 100644
index 0000000..1f531ef
--- /dev/null
+++ b/lib/AuxiliaryLossCriterion.lua
@@ -0,0 +1,51 @@
+require 'nn'
+local AuxiliaryLossCriterion, parent = torch.class('w2nn.AuxiliaryLossCriterion','nn.Criterion')
+
+function AuxiliaryLossCriterion:__init(base_criterion)
+   parent.__init(self)
+   self.base_criterion = base_criterion
+   self.criterions = {}
+   self.gradInput = {}
+   self.sizeAverage = false
+end
+function AuxiliaryLossCriterion:updateOutput(input, target)
+   local sum_output = 0
+   if type(input) == "table" then
+      -- model:training()
+      for i = 1, #input do
+         if self.criterions[i] == nil then
+            self.criterions[i] = self.base_criterion()
+            self.criterions[i].sizeAverage = self.sizeAverage
+            if input[i]:type() == "torch.CudaTensor" then
+               self.criterions[i]:cuda()
+            end
+         end
+         local output = self.criterions[i]:updateOutput(input[i], target)
+         sum_output = sum_output + output
+      end
+      self.output = sum_output / #input
+   else
+      -- model:evaluate()
+      if self.criterions[1] == nil then
+         self.criterions[1] = self.base_criterion()
+         self.criterions[1].sizeAverage = self.sizeAverage
+         if input:type() == "torch.CudaTensor" then
+            self.criterions[1]:cuda()
+         end
+      end
+      self.output = self.criterions[1]:updateOutput(input, target)
+   end
+   return self.output
+end
+
+function AuxiliaryLossCriterion:updateGradInput(input, target)
+   for i=1,#input do
+      local gradInput = self.criterions[i]:updateGradInput(input[i], target)
+      self.gradInput[i] = self.gradInput[i] or gradInput.new()
+      self.gradInput[i]:resizeAs(gradInput):copy(gradInput)
+   end
+   for i=#input+1, #self.gradInput do
+      self.gradInput[i] = nil
+   end
+   return self.gradInput
+end
diff --git a/lib/AuxiliaryLossTable.lua b/lib/AuxiliaryLossTable.lua
new file mode 100644
index 0000000..3cc0a34
--- /dev/null
+++ b/lib/AuxiliaryLossTable.lua
@@ -0,0 +1,40 @@
+require 'nn'
+local AuxiliaryLossTable, parent = torch.class('w2nn.AuxiliaryLossTable', 'nn.Module')
+
+function AuxiliaryLossTable:__init(i)
+   parent.__init(self)
+   self.i = i or 1
+   self.gradInput = {}
+   self.output_table = {}
+   self.output_tensor = torch.Tensor()
+end
+
+function AuxiliaryLossTable:updateOutput(input)
+   if self.train then
+      for i=1,#input do
+         self.output_table[i] = self.output_table[i] or input[1].new()
+         self.output_table[i]:resizeAs(input[i]):copy(input[i])
+      end
+      for i=#input+1, #self.output_table do
+         self.output_table[i] = nil
+      end
+      self.output = self.output_table
+   else
+      self.output_tensor:resizeAs(input[self.i])
+      self.output_tensor:copy(input[self.i])
+      self.output = self.output_tensor
+   end
+   return self.output
+end
+
+function AuxiliaryLossTable:updateGradInput(input, gradOutput)
+   for i=1,#input do
+      self.gradInput[i] = self.gradInput[i] or input[1].new()
+      self.gradInput[i]:resizeAs(input[i]):copy(gradOutput[i])
+   end
+   for i=#input+1, #self.gradInput do
+      self.gradInput[i] = nil
+   end
+
+   return self.gradInput
+end
diff --git a/lib/w2nn.lua b/lib/w2nn.lua
index bf20ec1..6a906f6 100644
--- a/lib/w2nn.lua
+++ b/lib/w2nn.lua
@@ -77,5 +77,8 @@ else
    require 'ShakeShakeTable'
    require 'PrintTable'
    require 'Print'
+   require 'AuxiliaryLossTable'
+   require 'AuxiliaryLossCriterion'
    return w2nn
 end
+
diff --git a/train.lua b/train.lua
index b2800a3..2fe9091 100644
--- a/train.lua
+++ b/train.lua
@@ -365,6 +365,10 @@ local function create_criterion(model)
       local bce = nn.BCECriterion()
      bce.sizeAverage = true
       return bce:cuda()
+   elseif settings.loss == "aux_bce" then
+      local aux = w2nn.AuxiliaryLossCriterion(nn.BCECriterion)
+      aux.sizeAverage = true
+      return aux:cuda()
    else
       error("unsupported loss .." .. settings.loss)
    end
@@ -500,7 +504,12 @@ local function train()
    transform_pool_init(reconstruct.has_resize(model), offset)
    local criterion = create_criterion(model)
-   local eval_metric = w2nn.ClippedMSECriterion(0, 1):cuda()
+   local eval_metric = nil
+   if settings.loss:find("aux_") ~= nil then
+      eval_metric = w2nn.AuxiliaryLossCriterion(w2nn.ClippedMSECriterion)
+   else
+      eval_metric = w2nn.ClippedMSECriterion()
+   end
    local adam_config = {
       xLearningRate = settings.learning_rate,
       xBatchSize = settings.batch_size,
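
Usage sketch (reviewer note, not part of the diff): the snippet below shows how the two new modules are meant to fit together. w2nn.AuxiliaryLossTable sits at the top of a multi-head model so that training sees every head while evaluation sees only head i, and w2nn.AuxiliaryLossCriterion averages the base loss over the heads. The toy two-head nn.Linear model, the tensor sizes, and the nn.MSECriterion choice are illustrative assumptions, not code from this change.

require 'nn'
require 'w2nn'  -- assumes the w2nn lib directory is on the Lua package path

-- Toy two-head model: both heads predict the same target.
local trunk = nn.Sequential()
trunk:add(nn.Linear(16, 32))
trunk:add(nn.ReLU())

local heads = nn.ConcatTable()
heads:add(nn.Linear(32, 4))  -- main head (index 1, kept at eval time)
heads:add(nn.Linear(32, 4))  -- auxiliary head (training only)

local model = nn.Sequential()
model:add(trunk)
model:add(heads)
model:add(w2nn.AuxiliaryLossTable(1))  -- table in training, head 1's tensor at eval

local criterion = w2nn.AuxiliaryLossCriterion(nn.MSECriterion)

local x = torch.Tensor(8, 16):uniform()
local y = torch.Tensor(8, 4):uniform()

-- Training: forward yields a table; the criterion averages the per-head losses
-- against the shared target, and backward returns a table of gradients.
model:training()
local out = model:forward(x)
local loss = criterion:forward(out, y)
model:backward(x, criterion:backward(out, y))

-- Evaluation: forward yields only the selected head's tensor, so downstream
-- code never sees the auxiliary branch.
model:evaluate()
local pred = model:forward(x)

Note that AuxiliaryLossCriterion:updateGradInput indexes input as a table, so backward is only valid on the training path; at evaluation time only updateOutput with a plain tensor is supported, which matches how train.lua uses it as eval_metric.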