diff --git a/lib/ClippedMSECriterion.lua b/lib/ClippedMSECriterion.lua index aae28d3..19336ee 100644 --- a/lib/ClippedMSECriterion.lua +++ b/lib/ClippedMSECriterion.lua @@ -8,7 +8,7 @@ function ClippedMSECriterion:__init(min, max) end function ClippedMSECriterion:updateOutput(input, target) self.diff:resizeAs(input):copy(input) - self.diff[torch.lt(self.diff, self.min)]:clamp(self.min, self.max) + self.diff:clamp(self.min, self.max) self.diff:add(-1, target) self.output = self.diff:pow(2):sum() / input:nElement() return self.output diff --git a/lib/ClippedWeightedHuberCriterion.lua b/lib/ClippedWeightedHuberCriterion.lua index 77f83a4..47d64e5 100644 --- a/lib/ClippedWeightedHuberCriterion.lua +++ b/lib/ClippedWeightedHuberCriterion.lua @@ -14,8 +14,7 @@ function ClippedWeightedHuberCriterion:__init(w, gamma, clip) end function ClippedWeightedHuberCriterion:updateOutput(input, target) self.diff:resizeAs(input):copy(input) - self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1] - self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2] + self.diff:clamp(self.clip[1], self.clip[2]) for i = 1, input:size(1) do self.diff[i]:add(-1, target[i]):cmul(self.weight) end