-- w2nn.ClippedMSECriterion: a mean-squared-error criterion whose prediction
-- is clamped to a fixed [min, max] range before it is compared with the
-- target. `parent` is the nn.Criterion base class (used by __init).
local ClippedMSECriterion, parent =
   torch.class('w2nn.ClippedMSECriterion', 'nn.Criterion')

-- Marks this criterion as one whose updateOutput fills self.instance_loss
-- with one loss value per sample; presumably consulted by training code --
-- confirm against callers.
ClippedMSECriterion.has_instance_loss = true

--- Build a clipped MSE criterion.
-- @tparam[opt=0] number min lower bound the input is clamped to
-- @tparam[opt=1] number max upper bound the input is clamped to
function ClippedMSECriterion:__init(min, max)
   parent.__init(self)

   -- Clamp range applied to the prediction in updateOutput.
   local lower, upper = min or 0, max or 1
   self.min, self.max = lower, upper

   -- Scratch tensors reused across calls: clamped(input) - target, and its square.
   self.diff, self.diff_pow2 = torch.Tensor(), torch.Tensor()

   -- Per-sample losses produced by the most recent updateOutput call.
   self.instance_loss = {}
end

--- Compute the clipped mean squared error for a mini-batch.
-- The input is clamped to [self.min, self.max] and then compared with the
-- target. One loss value per sample is recorded in self.instance_loss.
-- self.diff is kept for updateGradInput.
-- @param input prediction tensor; its first dimension is iterated as the
--   sample index (assumes batch-first with at least 2 dims -- TODO confirm)
-- @param target tensor with the same number of elements as input
-- @return mean over samples of each sample's mean squared error
--   (also stored in self.output)
function ClippedMSECriterion:updateOutput(input, target)
   -- diff = clamp(input, min, max) - target
   self.diff:resizeAs(input):copy(input)
   self.diff:clamp(self.min, self.max)
   self.diff:add(-1, target)
   self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)

   local batch_size = input:size(1)
   local total = 0
   self.instance_loss = {}
   for i = 1, batch_size do
      local sample = self.diff_pow2[i]
      local loss = sample:sum() / sample:nElement()
      self.instance_loss[i] = loss
      total = total + loss
   end

   -- Keep self.output equal to the value returned. Previously it held the
   -- un-normalized sum while the mean was returned, contrary to the
   -- nn.Criterion convention that self.output is the loss from forward().
   self.output = total / batch_size
   return self.output
end

--- Gradient of the loss with respect to the input.
-- Reuses self.diff (clamped input minus target) from the preceding
-- updateOutput call, so updateOutput must have been run first.
-- NOTE(review): the result is diff / N, i.e. half the analytic derivative
-- of the mean squared loss (2 * diff / N), and it is not masked where the
-- input was clamped. Both look deliberate; changing either would rescale
-- training, so confirm before touching.
-- @param input the same input passed to updateOutput
-- @param target unused here
-- @return self.gradInput
function ClippedMSECriterion:updateGradInput(input, target)
   local grad = self.gradInput
   grad:resizeAs(self.diff)
   grad:copy(self.diff)
   grad:mul(1.0 / input:nElement())
   return grad
end