2015-10-28 19:30:47 +13:00
|
|
|
-- Register the class 'w2nn.WeightedMSECriterion' with torch, deriving from
-- 'nn.Criterion'. `parent` is the base class, used below for the super-call
-- in the constructor.
local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion')
|
2015-10-26 13:23:52 +13:00
|
|
|
|
2015-10-28 19:30:47 +13:00
|
|
|
--- Construct the criterion.
-- @param w weight tensor; a private copy is stored so later changes to the
--          caller's tensor do not affect this criterion.
function WeightedMSECriterion:__init(w)
   parent.__init(self)

   -- Keep our own copy of the weights.
   local weight_copy = w:clone()
   self.weight = weight_copy

   -- Scratch buffers, resized on demand in updateOutput.
   self.diff = torch.Tensor()   -- weighted residual
   self.loss = torch.Tensor()   -- squared weighted residual
end
|
|
|
|
|
2015-10-28 19:30:47 +13:00
|
|
|
--- Forward pass: mean of the squared, weighted residual.
-- For every index i along the first dimension of `input`, the slice
-- input[i] - target[i] is multiplied element-wise by self.weight; the result
-- is kept in self.diff (it is read again by updateGradInput).
-- NOTE(review): the first dimension is presumably the mini-batch, and
-- self.weight presumably matches the shape of one sample -- confirm with caller.
-- @param input  tensor of predictions
-- @param target tensor of expected values, indexable along the same first dimension
-- @return the scalar loss, also stored in self.output
function WeightedMSECriterion:updateOutput(input, target)
   local diff = self.diff
   local weight = self.weight

   -- Start from a copy of the input, then turn it into the weighted residual
   -- slice by slice (diff[n] is a view, so the in-place ops update `diff`).
   diff:resizeAs(input)
   diff:copy(input)
   local count = input:size(1)
   for n = 1, count do
      local slice = diff[n]
      slice:add(-1, target[n])
      slice:cmul(weight)
   end

   -- Square the weighted residual element-wise and average it.
   local loss = self.loss
   loss:resizeAs(diff)
   loss:copy(diff)
   loss:cmul(diff)

   self.output = loss:mean()
   return self.output
end
|
|
|
|
|
2015-10-28 19:30:47 +13:00
|
|
|
--- Backward pass: gradient of the loss from updateOutput w.r.t. `input`.
-- updateOutput computes  loss = mean((w * (x - t))^2), so
--   dloss/dx = (2 / N) * w * (w * (x - t)) = (2 / N) * w * self.diff
-- where N = input:nElement() and self.diff already holds w * (x - t).
-- The previous implementation returned (2 / N) * self.diff, dropping one
-- factor of w, so the gradient did not match the loss being reported.
--
-- Relies on self.diff as left by the most recent updateOutput call with the
-- same input/target; `target` itself is not read here.
-- @param input  tensor of predictions (same one passed to updateOutput)
-- @param target tensor of expected values (unused; kept for the nn.Criterion interface)
-- @return self.gradInput, resized to match `input`
function WeightedMSECriterion:updateGradInput(input, target)
   local norm = 2.0 / input:nElement()
   local grad = self.gradInput

   grad:resizeAs(input):copy(self.diff)
   -- Apply the second factor of w per slice, mirroring how updateOutput
   -- applies the weight along the first dimension.
   for i = 1, input:size(1) do
      grad[i]:cmul(self.weight)
   end
   grad:mul(norm)

   return self.gradInput
end
|