Use conv2d instead of nn.SpatialConvolutionMM
This commit is contained in:
parent
a14e6acec3
commit
f65132dadb
|
@ -71,16 +71,12 @@ function data_augmentation.unsharp_mask(src, p)
|
|||
return src
|
||||
end
|
||||
end
|
||||
data_augmentation.blur_conv = {}
|
||||
function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
|
||||
size = size or "3"
|
||||
filters = utils.split(size, ",")
|
||||
for i = 1, #filters do
|
||||
local s = tonumber(filters[i])
|
||||
filters[i] = s
|
||||
if not data_augmentation.blur_conv[s] then
|
||||
data_augmentation.blur_conv[s] = nn.SpatialConvolutionMM(1, 1, s, s, 1, 1, (s - 1) / 2, (s - 1) / 2):noBias():cuda()
|
||||
end
|
||||
end
|
||||
if torch.uniform() < p then
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
|
@ -92,12 +88,7 @@ function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
|
|||
sigma = torch.uniform(sigma_min, sigma_max)
|
||||
end
|
||||
local kernel = iproc.gaussian2d(kernel_size, sigma)
|
||||
data_augmentation.blur_conv[kernel_size].weight:copy(kernel)
|
||||
local dest = torch.Tensor(3, src:size(2), src:size(3))
|
||||
dest[1]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[1]:reshape(1, src:size(2), src:size(3)):cuda()))
|
||||
dest[2]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[2]:reshape(1, src:size(2), src:size(3)):cuda()))
|
||||
dest[3]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[3]:reshape(1, src:size(2), src:size(3)):cuda()))
|
||||
|
||||
local dest = iproc.convolve(src, kernel, 'same')
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
local image = nil
|
||||
require 'dok'
|
||||
|
||||
local iproc = {}
|
||||
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
||||
|
@ -267,6 +268,67 @@ function iproc.gaussian2d(kernel_size, sigma)
|
|||
kernel:div(kernel:sum())
|
||||
return kernel
|
||||
end
|
||||
|
||||
-- from image.convolve
|
||||
function iproc.convolve(...)
|
||||
local dst,src,kernel,mode
|
||||
local args = {...}
|
||||
if select('#',...) == 4 then
|
||||
dst = args[1]
|
||||
src = args[2]
|
||||
kernel = args[3]
|
||||
mode = args[4]
|
||||
elseif select('#',...) == 3 then
|
||||
if type(args[3]) == 'string' then
|
||||
src = args[1]
|
||||
kernel = args[2]
|
||||
mode = args[3]
|
||||
else
|
||||
dst = args[1]
|
||||
src = args[2]
|
||||
kernel = args[3]
|
||||
end
|
||||
elseif select('#',...) == 2 then
|
||||
src = args[1]
|
||||
kernel = args[2]
|
||||
else
|
||||
print(dok.usage('iproc.convolve',
|
||||
'convolves an input image with a kernel, returns the result', nil,
|
||||
{type='torch.Tensor', help='input image', req=true},
|
||||
{type='torch.Tensor', help='kernel', req=true},
|
||||
{type='string', help='type: full | valid | same', default='valid'},
|
||||
'',
|
||||
{type='torch.Tensor', help='destination', req=true},
|
||||
{type='torch.Tensor', help='input image', req=true},
|
||||
{type='torch.Tensor', help='kernel', req=true},
|
||||
{type='string', help='type: full | valid | same', default='valid'}))
|
||||
dok.error('incorrect arguments', 'image.convolve')
|
||||
end
|
||||
if mode and mode ~= 'valid' and mode ~= 'full' and mode ~= 'same' then
|
||||
dok.error('mode has to be one of: full | valid | same', 'image.convolve')
|
||||
end
|
||||
local md = (((mode == 'full') or (mode == 'same')) and 'F') or 'V'
|
||||
if kernel:nDimension() == 2 and src:nDimension() == 3 then
|
||||
local k3d = src.new(src:size(1), kernel:size(1), kernel:size(2))
|
||||
for i = 1,src:size(1) do
|
||||
k3d[i]:copy(kernel)
|
||||
end
|
||||
kernel = k3d
|
||||
end
|
||||
if dst then
|
||||
torch.conv2(dst,src,kernel,md)
|
||||
else
|
||||
dst = torch.conv2(src,kernel,md)
|
||||
end
|
||||
if mode == 'same' then
|
||||
local cx = dst:dim()
|
||||
local cy = cx-1
|
||||
local ofy = math.ceil(kernel:size(cy)/2)
|
||||
local ofx = math.ceil(kernel:size(cx)/2)
|
||||
dst = dst:narrow(cy, ofy, src:size(cy)):narrow(cx, ofx, src:size(cx))
|
||||
end
|
||||
return dst
|
||||
end
|
||||
local function test_conversion()
|
||||
local a = torch.linspace(0, 255, 256):float():div(255.0)
|
||||
local b = iproc.float2byte(a)
|
||||
|
|
Loading…
Reference in a new issue