1
0
Fork 0
mirror of synced 2024-06-26 18:20:26 +12:00
waifu2x/lib/ClippedMSECriterion.lua

23 lines
779 B
Lua
Raw Normal View History

2016-05-27 19:56:38 +12:00
--- Mean-squared-error criterion evaluated on a clamped copy of the input.
--
-- loss = sum((clamp(input, min, max) - target)^2) / input:nElement()
--
-- Registered with torch as `w2nn.ClippedMSECriterion`, subclassing
-- `nn.Criterion`.
local ClippedMSECriterion, parent = torch.class('w2nn.ClippedMSECriterion','nn.Criterion')

--- Construct the criterion.
-- @tparam[opt=0] number min lower bound the input is clamped to
-- @tparam[opt=1] number max upper bound the input is clamped to
function ClippedMSECriterion:__init(min, max)
   parent.__init(self)
   self.min = min or 0
   self.max = max or 1
   -- Work buffers, resized in place (resizeAs) on every call.
   self.diff = torch.Tensor()      -- clamp(input) - target; read again by updateGradInput
   self.diff_pow2 = torch.Tensor() -- element-wise square of self.diff
end

--- Forward pass: compute the clipped MSE loss.
-- Side effect: leaves clamp(input, min, max) - target in self.diff, which
-- updateGradInput reuses.
-- NOTE(review): target is assumed to have the same number of elements as
-- input -- confirm against callers.
-- NOTE(review): an input with zero elements divides by zero here and yields
-- NaN.
-- @param input prediction tensor
-- @param target reference tensor
-- @treturn number the loss (also stored in self.output)
function ClippedMSECriterion:updateOutput(input, target)
   self.diff:resizeAs(input):copy(input)
   -- Score the prediction as if it had been clipped to [min, max] first, so
   -- out-of-range values are compared at the nearest bound.
   self.diff:clamp(self.min, self.max)
   self.diff:add(-1, target)
   self.diff_pow2:resizeAs(self.diff):copy(self.diff):pow(2)
   self.output = self.diff_pow2:sum() / input:nElement()
   return self.output
end

--- Backward pass: gradient of the loss with respect to input.
-- Uses self.diff from the most recent updateOutput call, so the forward pass
-- must have been run with the same input/target first. `input` only supplies
-- the element count; `target` is not read here.
-- NOTE(review): the clamp is not differentiated. Elements outside
-- [min, max] still receive (clamp(x) - target) / n instead of zero, so the
-- clamp is effectively treated as identity in the backward pass.
-- NOTE(review): the exact derivative of the mean of squares is
-- 2 * diff / n. Here it is diff / n, which scales the gradient by 1/2.
-- Both points look deliberate; confirm before changing, since either would
-- alter training.
-- @param input prediction tensor (only its element count is used)
-- @param target unused
-- @return self.gradInput (presumably allocated by nn.Criterion's constructor)
function ClippedMSECriterion:updateGradInput(input, target)
   local norm = 1.0 / input:nElement()
   self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
   return self.gradInput
end