Add random blur method for data augmentation
This commit is contained in:
parent
cabeeed2a7
commit
5a3d012f4e
|
@ -1,3 +1,5 @@
|
|||
require 'pl'
|
||||
require 'cunn'
|
||||
local iproc = require 'iproc'
|
||||
local gm = {}
|
||||
gm.Image = require 'graphicsmagick.Image'
|
||||
|
@ -69,6 +71,41 @@ function data_augmentation.unsharp_mask(src, p)
|
|||
return src
|
||||
end
|
||||
end
|
||||
data_augmentation.blur_conv = {}
|
||||
function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
|
||||
size = size or "3"
|
||||
filters = utils.split(size, ",")
|
||||
for i = 1, #filters do
|
||||
local s = tonumber(filters[i])
|
||||
filters[i] = s
|
||||
if not data_augmentation.blur_conv[s] then
|
||||
data_augmentation.blur_conv[s] = nn.SpatialConvolutionMM(1, 1, s, s, 1, 1, (s - 1) / 2, (s - 1) / 2):noBias():cuda()
|
||||
end
|
||||
end
|
||||
if torch.uniform() < p then
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
local kernel_size = filters[torch.random(1, #filters)]
|
||||
local sigma
|
||||
if sigma_min == sigma_max then
|
||||
sigma = sigma_min
|
||||
else
|
||||
sigma = torch.uniform(sigma_min, sigma_max)
|
||||
end
|
||||
local kernel = iproc.gaussian2d(kernel_size, sigma)
|
||||
data_augmentation.blur_conv[kernel_size].weight:copy(kernel)
|
||||
local dest = torch.Tensor(3, src:size(2), src:size(3))
|
||||
dest[1]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[1]:reshape(1, src:size(2), src:size(3)):cuda()))
|
||||
dest[2]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[2]:reshape(1, src:size(2), src:size(3)):cuda()))
|
||||
dest[3]:copy(data_augmentation.blur_conv[kernel_size]:forward(src[3]:reshape(1, src:size(2), src:size(3)):cuda()))
|
||||
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
return dest
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
function data_augmentation.shift_1px(src)
|
||||
-- reducing the even/odd issue in nearest neighbor scaler.
|
||||
local direction = torch.random(1, 4)
|
||||
|
@ -119,4 +156,20 @@ function data_augmentation.flip(src)
|
|||
end
|
||||
return dest
|
||||
end
|
||||
|
||||
local function test_blur()
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
local image =require 'image'
|
||||
local src = image.lena()
|
||||
|
||||
image.display({image = src, min=0, max=1})
|
||||
local dest = data_augmentation.blur(src, 1.0, "3,5", 0.5, 0.6)
|
||||
image.display({image = dest, min=0, max=1})
|
||||
dest = data_augmentation.blur(src, 1.0, "3", 1.0, 1.0)
|
||||
image.display({image = dest, min=0, max=1})
|
||||
dest = data_augmentation.blur(src, 1.0, "5", 0.75, 0.75)
|
||||
image.display({image = dest, min=0, max=1})
|
||||
end
|
||||
--test_blur()
|
||||
|
||||
return data_augmentation
|
||||
|
|
|
@ -254,7 +254,19 @@ function iproc.yuv2rgb(...)
|
|||
-- return RGB image
|
||||
return output
|
||||
end
|
||||
|
||||
function iproc.gaussian2d(kernel_size, sigma)
|
||||
sigma = sigma or 1
|
||||
local kernel = torch.Tensor(kernel_size, kernel_size)
|
||||
local u = math.floor(kernel_size / 2) + 1
|
||||
local amp = (1 / math.sqrt(2 * math.pi * sigma^2))
|
||||
for x = 1, kernel_size do
|
||||
for y = 1, kernel_size do
|
||||
kernel[x][y] = amp * math.exp(-((x - u)^2 + (y - u)^2) / (2 * sigma^2))
|
||||
end
|
||||
end
|
||||
kernel:div(kernel:sum())
|
||||
return kernel
|
||||
end
|
||||
local function test_conversion()
|
||||
local a = torch.linspace(0, 255, 256):float():div(255.0)
|
||||
local b = iproc.float2byte(a)
|
||||
|
@ -286,9 +298,17 @@ local function test_flip()
|
|||
print((image.vflip(src) - iproc.vflip(src)):sum())
|
||||
print((image.vflip(src_byte) - iproc.vflip(src_byte)):sum())
|
||||
end
|
||||
local function test_gaussian2d()
|
||||
local t = {3, 5, 7}
|
||||
for i = 1, #t do
|
||||
local kp = iproc.gaussian2d(t[i], 0.5)
|
||||
print(kp)
|
||||
end
|
||||
end
|
||||
|
||||
--test_conversion()
|
||||
--test_flip()
|
||||
--test_gaussian2d()
|
||||
|
||||
return iproc
|
||||
|
||||
|
|
|
@ -56,6 +56,10 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
|
|||
dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
|
||||
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.blur(dest, options.random_blur_rate,
|
||||
options.random_blur_size,
|
||||
options.random_blur_min,
|
||||
options.random_blur_max)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
|
|
|
@ -32,6 +32,10 @@ cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise
|
|||
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
|
||||
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
|
||||
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
|
||||
cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0-1.0)')
|
||||
cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
|
||||
cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
|
||||
cmd:option("-random_blur_sigma_max", 0.75, 'max sigma for random gaussian blur')
|
||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||
cmd:option("-crop_size", 48, 'crop size')
|
||||
|
|
12
train.lua
12
train.lua
|
@ -97,6 +97,10 @@ local function transform_pool_init(has_resize, offset)
|
|||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
random_blur_rate = settings.random_blur_rate,
|
||||
random_blur_size = settings.random_blur_size,
|
||||
random_blur_sigma_min = settings.random_blur_sigma_min,
|
||||
random_blur_sigma_max = settings.random_blur_sigma_max,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
|
@ -114,6 +118,10 @@ local function transform_pool_init(has_resize, offset)
|
|||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
random_blur_rate = settings.random_blur_rate,
|
||||
random_blur_size = settings.random_blur_size,
|
||||
random_blur_sigma_min = settings.random_blur_sigma_min,
|
||||
random_blur_sigma_max = settings.random_blur_sigma_max,
|
||||
max_size = settings.max_size,
|
||||
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
|
@ -132,6 +140,10 @@ local function transform_pool_init(has_resize, offset)
|
|||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
random_blur_rate = settings.random_blur_rate,
|
||||
random_blur_size = settings.random_blur_size,
|
||||
random_blur_sigma_min = settings.random_blur_sigma_min,
|
||||
random_blur_sigma_max = settings.random_blur_sigma_max,
|
||||
max_size = settings.max_size,
|
||||
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
|
||||
nr_rate = settings.nr_rate,
|
||||
|
|
Loading…
Reference in a new issue