1
0
Fork 0
mirror of synced 2024-06-01 10:39:30 +12:00

Use roundf-like clip for 8 bit-depth image

Maybe PSNR +0.03 improved by this commit
This commit is contained in:
nagadomi 2015-11-08 05:44:14 +09:00
parent bd63f99b59
commit 797b45ae23
4 changed files with 33 additions and 10 deletions

View file

@ -1,8 +1,9 @@
-- ref: https://en.wikipedia.org/wiki/Huber_loss
local WeightedHuberCriterion, parent = torch.class('w2nn.WeightedHuberCriterion','nn.Criterion')
local ClippedWeightedHuberCriterion, parent = torch.class('w2nn.ClippedWeightedHuberCriterion','nn.Criterion')
function WeightedHuberCriterion:__init(w, gamma)
function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
parent.__init(self)
self.clip = clip
self.gamma = gamma or 1.0
self.weight = w:clone()
self.diff = torch.Tensor()
@ -11,8 +12,10 @@ function WeightedHuberCriterion:__init(w, gamma)
self.square_loss_buff = torch.Tensor()
self.linear_loss_buff = torch.Tensor()
end
function WeightedHuberCriterion:updateOutput(input, target)
function ClippedWeightedHuberCriterion:updateOutput(input, target)
self.diff:resizeAs(input):copy(input)
self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
for i = 1, input:size(1) do
self.diff[i]:add(-1, target[i]):cmul(self.weight)
end
@ -27,7 +30,7 @@ function WeightedHuberCriterion:updateOutput(input, target)
self.output = (square_loss + linear_loss) / input:nElement()
return self.output
end
function WeightedHuberCriterion:updateGradInput(input, target)
function ClippedWeightedHuberCriterion:updateGradInput(input, target)
local norm = 1.0 / input:nElement()
self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
local outlier = torch.ge(self.diff_abs, self.gamma)

View file

@ -4,6 +4,9 @@ require 'pl'
local image_loader = {}
local clip_eta8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
local clip_eta16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
function image_loader.decode_float(blob)
local im, alpha = image_loader.decode_byte(blob)
if im then
@ -25,13 +28,30 @@ function image_loader.encode_png(rgb, alpha, depth)
rgba[2]:copy(rgb[2])
rgba[3]:copy(rgb[3])
rgba[4]:copy(alpha)
if depth < 16 then
rgba:add(clip_eta8)
rgba[torch.lt(rgba, 0.0)] = 0.0
rgba[torch.gt(rgba, 1.0)] = 1.0
else
rgba:add(clip_eta16)
rgba[torch.lt(rgba, 0.0)] = 0.0
rgba[torch.gt(rgba, 1.0)] = 1.0
end
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
im:format("png")
return im:depth(depth):toBlob(9)
return im:depth(depth):format("PNG"):toBlob(9)
else
if depth < 16 then
rgb = rgb:clone():add(clip_eta8)
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
else
rgb = rgb:clone():add(clip_eta16)
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
end
local im = gm.Image(rgb, "RGB", "DHW")
im:format("png")
return im:depth(depth):toBlob(9)
return im:depth(depth):format("PNG"):toBlob(9)
end
end
function image_loader.save_png(filename, rgb, alpha, depth)

View file

@ -20,7 +20,7 @@ else
require 'LeakyReLU_deprecated'
require 'DepthExpand2x'
require 'WeightedMSECriterion'
require 'WeightedHuberCriterion'
require 'ClippedWeightedHuberCriterion'
require 'cleanup_model'
return w2nn
end

View file

@ -76,7 +76,7 @@ local function create_criterion(model)
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
return w2nn.WeightedHuberCriterion(weight, 0.1):cuda()
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
else
return nn.MSECriterion():cuda()
end