performance tuning
This commit is contained in:
parent
98a091a7cb
commit
86d8fe96da
|
@ -5,14 +5,18 @@ function AuxiliaryLossCriterion:__init(base_criterion, args)
|
|||
parent.__init(self)
|
||||
self.base_criterion = base_criterion
|
||||
self.args = args
|
||||
self.criterions = {}
|
||||
self.gradInput = {}
|
||||
self.sizeAverage = false
|
||||
self.criterions = {}
|
||||
if self.base_criterion.has_instance_loss then
|
||||
self.instance_loss = {}
|
||||
end
|
||||
end
|
||||
function AuxiliaryLossCriterion:updateOutput(input, target)
|
||||
local sum_output = 0
|
||||
if type(input) == "table" then
|
||||
-- model:training()
|
||||
self.output = 0
|
||||
for i = 1, #input do
|
||||
if self.criterions[i] == nil then
|
||||
if self.args ~= nil then
|
||||
|
@ -25,10 +29,22 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
|||
self.criterions[i]:cuda()
|
||||
end
|
||||
end
|
||||
local output = self.criterions[i]:updateOutput(input[i], target)
|
||||
sum_output = sum_output + output
|
||||
self.output = self.output + self.criterions[i]:updateOutput(input[i], target) / #input
|
||||
|
||||
if self.instance_loss then
|
||||
local batch_size = #self.criterions[i].instance_loss
|
||||
local scale = 1.0 / #input
|
||||
if i == 1 then
|
||||
for j = 1, batch_size do
|
||||
self.instance_loss[j] = self.criterions[i].instance_loss[j] * scale
|
||||
end
|
||||
else
|
||||
for j = 1, batch_size do
|
||||
self.instance_loss[j] = self.instance_loss[j] + self.criterions[i].instance_loss[j] * scale
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
self.output = sum_output / #input
|
||||
else
|
||||
-- model:evaluate()
|
||||
if self.criterions[1] == nil then
|
||||
|
@ -43,6 +59,12 @@ function AuxiliaryLossCriterion:updateOutput(input, target)
|
|||
end
|
||||
end
|
||||
self.output = self.criterions[1]:updateOutput(input, target)
|
||||
if self.instance_loss then
|
||||
local batch_size = #self.criterions[1].instance_loss
|
||||
for j = 1, batch_size do
|
||||
self.instance_loss[j] = self.criterions[1].instance_loss[j]
|
||||
end
|
||||
end
|
||||
end
|
||||
return self.output
|
||||
end
|
||||
|
|
|
@ -1,19 +1,28 @@
|
|||
local ClippedMSECriterion, parent = torch.class('w2nn.ClippedMSECriterion','nn.Criterion')
|
||||
|
||||
ClippedMSECriterion.has_instance_loss = true
|
||||
function ClippedMSECriterion:__init(min, max)
|
||||
parent.__init(self)
|
||||
self.min = min or 0
|
||||
self.max = max or 1
|
||||
self.diff = torch.Tensor()
|
||||
self.diff_pow2 = torch.Tensor()
|
||||
self.instance_loss = {}
|
||||
end
|
||||
function ClippedMSECriterion:updateOutput(input, target)
|
||||
self.diff:resizeAs(input):copy(input)
|
||||
self.diff:clamp(self.min, self.max)
|
||||
self.diff:add(-1, target)
|
||||
self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
|
||||
self.output = self.diff_pow2:sum() / input:nElement()
|
||||
return self.output
|
||||
self.instance_loss = {}
|
||||
self.output = 0
|
||||
local scale = 1.0 / input:size(1)
|
||||
for i = 1, input:size(1) do
|
||||
local instance_loss = self.diff_pow2[i]:sum() / self.diff_pow2[i]:nElement()
|
||||
self.instance_loss[i] = instance_loss
|
||||
self.output = self.output + instance_loss
|
||||
end
|
||||
return self.output / input:size(1)
|
||||
end
|
||||
function ClippedMSECriterion:updateGradInput(input, target)
|
||||
local norm = 1.0 / input:nElement()
|
||||
|
|
|
@ -55,9 +55,7 @@ function LBPCriterion:updateOutput(input, target)
|
|||
|
||||
-- huber loss
|
||||
self.diff:resizeAs(lb1):copy(lb1)
|
||||
for i = 1, lb1:size(1) do
|
||||
self.diff[i]:add(-1, lb2[i])
|
||||
end
|
||||
self.diff:add(-1, lb2)
|
||||
self.diff_abs:resizeAs(self.diff):copy(self.diff):abs()
|
||||
|
||||
local square_targets = self.diff[torch.lt(self.diff_abs, self.gamma)]
|
||||
|
|
|
@ -44,29 +44,12 @@ local function minibatch_adam(model, criterion, eval_metric,
|
|||
gradParameters:zero()
|
||||
local output = model:forward(inputs)
|
||||
local f = criterion:forward(output, targets)
|
||||
local se = 0
|
||||
local se = eval_metric:forward(output, targets)
|
||||
if config.xInstanceLoss then
|
||||
if type(output) then
|
||||
local tbl = {}
|
||||
for i = 1, batch_size do
|
||||
for j = 1, #output do
|
||||
tbl[j] = output[j][i]
|
||||
end
|
||||
local el = eval_metric:forward(tbl, targets[i])
|
||||
se = se + el
|
||||
instance_loss[shuffle[t + i - 1]] = el
|
||||
end
|
||||
se = (se / batch_size)
|
||||
else
|
||||
for i = 1, batch_size do
|
||||
local el = eval_metric:forward(output[i], targets[i])
|
||||
se = se + el
|
||||
instance_loss[shuffle[t + i - 1]] = el
|
||||
end
|
||||
se = (se / batch_size)
|
||||
end
|
||||
else
|
||||
se = eval_metric:forward(output, targets)
|
||||
assert(eval_metric.instance_loss, "eval metric does not support instalce_loss")
|
||||
for i = 1, #eval_metric.instance_loss do
|
||||
instance_loss[shuffle[t + i - 1]] = eval_metric.instance_loss[i]
|
||||
end
|
||||
end
|
||||
sum_psnr = sum_psnr + (10 * math.log10(1 / (se + 1.0e-6)))
|
||||
sum_eval = sum_eval + se
|
||||
|
|
Loading…
Reference in a new issue