From 9ec1f5159b90f497b9f93e817dbc7c96b79dff6c Mon Sep 17 00:00:00 2001 From: nagadomi Date: Mon, 4 Jul 2016 11:16:04 +0900 Subject: [PATCH] Fix a bug in ClippedMSECriterion --- lib/ClippedMSECriterion.lua | 2 +- lib/ClippedWeightedHuberCriterion.lua | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/ClippedMSECriterion.lua b/lib/ClippedMSECriterion.lua index aae28d3..19336ee 100644 --- a/lib/ClippedMSECriterion.lua +++ b/lib/ClippedMSECriterion.lua @@ -8,7 +8,7 @@ function ClippedMSECriterion:__init(min, max) end function ClippedMSECriterion:updateOutput(input, target) self.diff:resizeAs(input):copy(input) - self.diff[torch.lt(self.diff, self.min)]:clamp(self.min, self.max) + self.diff:clamp(self.min, self.max) self.diff:add(-1, target) self.output = self.diff:pow(2):sum() / input:nElement() return self.output diff --git a/lib/ClippedWeightedHuberCriterion.lua b/lib/ClippedWeightedHuberCriterion.lua index 77f83a4..47d64e5 100644 --- a/lib/ClippedWeightedHuberCriterion.lua +++ b/lib/ClippedWeightedHuberCriterion.lua @@ -14,8 +14,7 @@ function ClippedWeightedHuberCriterion:__init(w, gamma, clip) end function ClippedWeightedHuberCriterion:updateOutput(input, target) self.diff:resizeAs(input):copy(input) - self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1] - self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2] + self.diff:clamp(self.clip[1], self.clip[2]) for i = 1, input:size(1) do self.diff[i]:add(-1, target[i]):cmul(self.weight) end