1
0
Fork 0
mirror of synced 2024-06-01 10:39:30 +12:00

Use conv2d instead of nn.SpatialConvolutionMM

This commit is contained in:
nagadomi 2016-09-24 05:54:18 +09:00
parent a14e6acec3
commit f65132dadb
2 changed files with 63 additions and 10 deletions

View file

@ -71,16 +71,12 @@ function data_augmentation.unsharp_mask(src, p)
return src
end
end
data_augmentation.blur_conv = {}
function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
size = size or "3"
filters = utils.split(size, ",")
for i = 1, #filters do
local s = tonumber(filters[i])
filters[i] = s
if not data_augmentation.blur_conv[s] then
data_augmentation.blur_conv[s] = nn.SpatialConvolutionMM(1, 1, s, s, 1, 1, (s - 1) / 2, (s - 1) / 2):noBias():cuda()
end
end
if torch.uniform() < p then
local src, conversion = iproc.byte2float(src)
@ -92,12 +88,7 @@ function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
sigma = torch.uniform(sigma_min, sigma_max)
end
local kernel = iproc.gaussian2d(kernel_size, sigma)
data_augmentation.blur_conv[kernel_size].weight:copy(kernel)
local dest = torch.Tensor(3, src:size(2), src:size(3))
dest[1]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[1]:reshape(1, src:size(2), src:size(3)):cuda()))
dest[2]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[2]:reshape(1, src:size(2), src:size(3)):cuda()))
dest[3]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[3]:reshape(1, src:size(2), src:size(3)):cuda()))
local dest = iproc.convolve(src, kernel, 'same')
if conversion then
dest = iproc.float2byte(dest)
end

View file

@ -1,6 +1,7 @@
local gm = {}
gm.Image = require 'graphicsmagick.Image'
local image = nil
require 'dok'
local iproc = {}
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
@ -267,6 +268,67 @@ function iproc.gaussian2d(kernel_size, sigma)
kernel:div(kernel:sum())
return kernel
end
--- Convolve an image with a kernel (adapted from image.convolve).
--
-- Accepted call forms:
--   iproc.convolve(src, kernel)
--   iproc.convolve(src, kernel, mode)
--   iproc.convolve(dst, src, kernel)
--   iproc.convolve(dst, src, kernel, mode)
-- A three-argument call is read as (src, kernel, mode) only when the third
-- argument is a string; otherwise it is read as (dst, src, kernel).
--
-- mode is one of 'full' | 'valid' | 'same'. When it is omitted a valid
-- convolution is computed. 'same' computes a full convolution and then
-- crops the last two dimensions back to the size of src.
--
-- Returns the result tensor. With mode == 'same' the returned value is the
-- narrowed result, not the dst that was passed in.
-- NOTE(review): presumably the caller's dst then still holds the uncropped
-- full convolution -- confirm against torch.conv2 before relying on it.
function iproc.convolve(...)
   local argc = select('#', ...)
   local dst, src, kernel, mode
   if argc == 4 then
      dst, src, kernel, mode = ...
   elseif argc == 3 then
      if type((select(3, ...))) == 'string' then
         src, kernel, mode = ...
      else
         dst, src, kernel = ...
      end
   elseif argc == 2 then
      src, kernel = ...
   else
      print(dok.usage('iproc.convolve',
                      'convolves an input image with a kernel, returns the result', nil,
                      {type='torch.Tensor', help='input image', req=true},
                      {type='torch.Tensor', help='kernel', req=true},
                      {type='string', help='type: full | valid | same', default='valid'},
                      '',
                      {type='torch.Tensor', help='destination', req=true},
                      {type='torch.Tensor', help='input image', req=true},
                      {type='torch.Tensor', help='kernel', req=true},
                      {type='string', help='type: full | valid | same', default='valid'}))
      -- Report the error under this function's own name (the copied code
      -- still said 'image.convolve').
      dok.error('incorrect arguments', 'iproc.convolve')
   end
   if mode and mode ~= 'valid' and mode ~= 'full' and mode ~= 'same' then
      dok.error('mode has to be one of: full | valid | same', 'iproc.convolve')
   end

   -- 'same' is derived from a full convolution, so both map to 'F'.
   local md = (((mode == 'full') or (mode == 'same')) and 'F') or 'V'

   -- A single 2D kernel applied to a 3D source is replicated once per
   -- slice along the first dimension of src.
   if kernel:nDimension() == 2 and src:nDimension() == 3 then
      local k3d = src.new(src:size(1), kernel:size(1), kernel:size(2))
      for i = 1, src:size(1) do
         k3d[i]:copy(kernel)
      end
      kernel = k3d
   end

   if dst then
      torch.conv2(dst, src, kernel, md)
   else
      dst = torch.conv2(src, kernel, md)
   end

   if mode == 'same' then
      -- Crop the last two dimensions of the full result back to src's size,
      -- starting ceil(kernel_size / 2) into each dimension.
      local cx = dst:dim()
      local cy = cx - 1
      local ofy = math.ceil(kernel:size(cy) / 2)
      local ofx = math.ceil(kernel:size(cx) / 2)
      dst = dst:narrow(cy, ofy, src:size(cy)):narrow(cx, ofx, src:size(cx))
   end
   return dst
end
local function test_conversion()
local a = torch.linspace(0, 255, 256):float():div(255.0)
local b = iproc.float2byte(a)