diff --git a/lib/alpha_util.lua b/lib/alpha_util.lua index 9105f06..6a0228b 100644 --- a/lib/alpha_util.lua +++ b/lib/alpha_util.lua @@ -5,15 +5,15 @@ local iproc = require 'iproc' local gm = require 'graphicsmagick' alpha_util = {} -alpha_util.sum2d = nn.SpatialConvolutionMM(1, 1, 3, 3, 1, 1, 1, 1):cuda() -alpha_util.sum2d.weight:fill(1) -alpha_util.sum2d.bias:zero() function alpha_util.make_border(rgb, alpha, offset) if not alpha then return rgb end - + local sum2d = nn.SpatialConvolutionMM(1, 1, 3, 3, 1, 1, 1, 1):cuda() + sum2d.weight:fill(1) + sum2d.bias:zero() + local mask = alpha:clone() mask[torch.gt(mask, 0.0)] = 1 mask[torch.eq(mask, 0.0)] = 0 @@ -26,10 +26,10 @@ function alpha_util.make_border(rgb, alpha, offset) rgb[3][mask_nega] = 0 for i = 1, offset do - local mask_weight = alpha_util.sum2d:forward(mask:cuda()):float() + local mask_weight = sum2d:forward(mask:cuda()):float() local border = rgb:clone() for j = 1, 3 do - border[j]:copy(alpha_util.sum2d:forward(rgb[j]:reshape(1, rgb:size(2), rgb:size(3)):cuda())) + border[j]:copy(sum2d:forward(rgb[j]:reshape(1, rgb:size(2), rgb:size(3)):cuda())) border[j]:cdiv((mask_weight + eps)) rgb[j][mask_nega] = border[j][mask_nega] end