Use roundf-like clip for 8 bit-depth image
Maybe PSNR +0.03 improved by this commit
This commit is contained in:
parent
bd63f99b59
commit
797b45ae23
|
@ -1,8 +1,9 @@
|
|||
-- ref: https://en.wikipedia.org/wiki/Huber_loss
|
||||
local WeightedHuberCriterion, parent = torch.class('w2nn.WeightedHuberCriterion','nn.Criterion')
|
||||
local ClippedWeightedHuberCriterion, parent = torch.class('w2nn.ClippedWeightedHuberCriterion','nn.Criterion')
|
||||
|
||||
function WeightedHuberCriterion:__init(w, gamma)
|
||||
function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
|
||||
parent.__init(self)
|
||||
self.clip = clip
|
||||
self.gamma = gamma or 1.0
|
||||
self.weight = w:clone()
|
||||
self.diff = torch.Tensor()
|
||||
|
@ -11,8 +12,10 @@ function WeightedHuberCriterion:__init(w, gamma)
|
|||
self.square_loss_buff = torch.Tensor()
|
||||
self.linear_loss_buff = torch.Tensor()
|
||||
end
|
||||
function WeightedHuberCriterion:updateOutput(input, target)
|
||||
function ClippedWeightedHuberCriterion:updateOutput(input, target)
|
||||
self.diff:resizeAs(input):copy(input)
|
||||
self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
|
||||
self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
|
||||
for i = 1, input:size(1) do
|
||||
self.diff[i]:add(-1, target[i]):cmul(self.weight)
|
||||
end
|
||||
|
@ -27,7 +30,7 @@ function WeightedHuberCriterion:updateOutput(input, target)
|
|||
self.output = (square_loss + linear_loss) / input:nElement()
|
||||
return self.output
|
||||
end
|
||||
function WeightedHuberCriterion:updateGradInput(input, target)
|
||||
function ClippedWeightedHuberCriterion:updateGradInput(input, target)
|
||||
local norm = 1.0 / input:nElement()
|
||||
self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
|
||||
local outlier = torch.ge(self.diff_abs, self.gamma)
|
|
@ -4,6 +4,9 @@ require 'pl'
|
|||
|
||||
local image_loader = {}
|
||||
|
||||
local clip_eta8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
||||
local clip_eta16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
|
||||
|
||||
function image_loader.decode_float(blob)
|
||||
local im, alpha = image_loader.decode_byte(blob)
|
||||
if im then
|
||||
|
@ -25,13 +28,30 @@ function image_loader.encode_png(rgb, alpha, depth)
|
|||
rgba[2]:copy(rgb[2])
|
||||
rgba[3]:copy(rgb[3])
|
||||
rgba[4]:copy(alpha)
|
||||
|
||||
if depth < 16 then
|
||||
rgba:add(clip_eta8)
|
||||
rgba[torch.lt(rgba, 0.0)] = 0.0
|
||||
rgba[torch.gt(rgba, 1.0)] = 1.0
|
||||
else
|
||||
rgba:add(clip_eta16)
|
||||
rgba[torch.lt(rgba, 0.0)] = 0.0
|
||||
rgba[torch.gt(rgba, 1.0)] = 1.0
|
||||
end
|
||||
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
|
||||
im:format("png")
|
||||
return im:depth(depth):toBlob(9)
|
||||
return im:depth(depth):format("PNG"):toBlob(9)
|
||||
else
|
||||
if depth < 16 then
|
||||
rgb = rgb:clone():add(clip_eta8)
|
||||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
else
|
||||
rgb = rgb:clone():add(clip_eta16)
|
||||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
end
|
||||
local im = gm.Image(rgb, "RGB", "DHW")
|
||||
im:format("png")
|
||||
return im:depth(depth):toBlob(9)
|
||||
return im:depth(depth):format("PNG"):toBlob(9)
|
||||
end
|
||||
end
|
||||
function image_loader.save_png(filename, rgb, alpha, depth)
|
||||
|
|
|
@ -20,7 +20,7 @@ else
|
|||
require 'LeakyReLU_deprecated'
|
||||
require 'DepthExpand2x'
|
||||
require 'WeightedMSECriterion'
|
||||
require 'WeightedHuberCriterion'
|
||||
require 'ClippedWeightedHuberCriterion'
|
||||
require 'cleanup_model'
|
||||
return w2nn
|
||||
end
|
||||
|
|
|
@ -76,7 +76,7 @@ local function create_criterion(model)
|
|||
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
else
|
||||
return nn.MSECriterion():cuda()
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue