1
0
Fork 0
mirror of synced 2024-10-04 04:06:09 +13:00
waifu2x/lib/RGBWeightedMSECriterion.lua
nagadomi 8dea362bed sync from internal repo
- Memory compression by snappy (lua-csnappy)
- Use RGB-wise Weighted MSE(R*0.299, G*0.587, B*0.114) instead of MSE
- Aggressive cropping for edge region
and some other changes.
2015-10-26 09:23:52 +09:00

25 lines
728 B
Lua

--- Weighted mean-squared-error criterion.
--
-- For a mini-batch `input`/`target` whose first dimension indexes samples,
-- the loss is
--
--    loss = mean over all elements of (weight .* (input[i] - target[i]))^2
--
-- `weight` is multiplied element-wise into every sample, so it must hold the
-- same number of elements as a single sample `input[i]`.
-- NOTE(review): the weight layout (e.g. a per-pixel map filled with per-channel
-- RGB weights) is decided by the caller; confirm against the construction site.
local RGBWeightedMSECriterion, parent = torch.class('mynn.RGBWeightedMSECriterion','nn.Criterion')

-- Fill self.diff with weight .* (input[i] - target[i]) for every sample i and
-- return it. Shared by the forward and backward passes.
local function compute_weighted_diff(self, input, target)
   self.diff:resizeAs(input):copy(input)
   for i = 1, input:size(1) do
      self.diff[i]:add(-1, target[i]):cmul(self.weight)
   end
   return self.diff
end

--- Create the criterion.
-- @param w weight tensor; cloned, so later changes to the caller's tensor do
--          not affect this criterion.
function RGBWeightedMSECriterion:__init(w)
   parent.__init(self)
   self.weight = w:clone()
   self.diff = torch.Tensor()   -- weighted residual buffer, reused across calls
   self.loss = torch.Tensor()   -- element-wise squared weighted residual buffer
end

--- Forward pass.
-- @param input  batch of predictions (first dimension = sample index)
-- @param target batch of reference values, indexed like `input`
-- @return the mean of the squared weighted residuals (a number)
function RGBWeightedMSECriterion:updateOutput(input, target)
   compute_weighted_diff(self, input, target)
   self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
   self.output = self.loss:mean()
   return self.output
end

--- Backward pass: gradient of the updateOutput loss with respect to `input`.
--
--    d/dx mean((w .* (x - t))^2) = (2 / N) * w .* (w .* (x - t))
--
-- where N = input:nElement() (the count `mean` divides by). Both the 2/N
-- factor and the second multiplication by `weight` are required for this to be
-- the true derivative of the forward loss.
-- NOTE: gradient magnitude differs from a plain `w .* (x - t)`; learning rates
-- tuned against that un-normalized gradient may need retuning.
-- @param input  batch of predictions (first dimension = sample index)
-- @param target batch of reference values, indexed like `input`
-- @return self.gradInput, a tensor shaped like `input`
function RGBWeightedMSECriterion:updateGradInput(input, target)
   -- Recompute instead of reusing self.diff, so the result depends only on the
   -- arguments and not on which tensors the last updateOutput call received.
   compute_weighted_diff(self, input, target)
   self.gradInput:resizeAs(input):copy(self.diff)
   for i = 1, input:size(1) do
      self.gradInput[i]:cmul(self.weight)
   end
   self.gradInput:mul(2.0 / input:nElement())
   return self.gradInput
end