1
0
Fork 0
mirror of synced 2024-05-23 14:19:38 +12:00

Fix a bug in ClippedMSECriterion

This commit is contained in:
nagadomi 2016-07-04 11:16:04 +09:00
parent d4dc8204a6
commit 9ec1f5159b
2 changed files with 2 additions and 3 deletions

View file

@ -8,7 +8,7 @@ function ClippedMSECriterion:__init(min, max)
end
function ClippedMSECriterion:updateOutput(input, target)
self.diff:resizeAs(input):copy(input)
self.diff[torch.lt(self.diff, self.min)]:clamp(self.min, self.max)
self.diff:clamp(self.min, self.max)
self.diff:add(-1, target)
self.output = self.diff:pow(2):sum() / input:nElement()
return self.output

View file

@ -14,8 +14,7 @@ function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
end
function ClippedWeightedHuberCriterion:updateOutput(input, target)
self.diff:resizeAs(input):copy(input)
self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
self.diff:clamp(self.clip[1], self.clip[2])
for i = 1, input:size(1) do
self.diff[i]:add(-1, target[i]):cmul(self.weight)
end