Add data augmentation for user method
This commit is contained in:
parent
a72af8cfef
commit
b066761cdc
|
@ -96,6 +96,49 @@ function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
|
||||||
return src
|
return src
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
-- Randomly rescale an image pair by a shared factor.
-- One torch.uniform() draw is compared against `p`; when the draw is below
-- `p`, a scale factor is drawn from torch.uniform(scale_min, scale_max) and
-- both images are resized to the same floored width/height through
-- iproc.scale with the "Triangle" filter argument. Otherwise the inputs are
-- returned untouched.
-- x, y must agree in size along dimensions 2 and 3 (asserted only when the
-- augmentation is actually applied).
-- Returns the (possibly rescaled) pair x, y.
function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
   if not (torch.uniform() < p) then
      -- augmentation not selected for this sample
      return x, y
   end
   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
   local factor = torch.uniform(scale_min, scale_max)
   -- target size is derived from x only; y has the same size per the assert
   local new_h = math.floor(x:size(2) * factor)
   local new_w = math.floor(x:size(3) * factor)
   local scaled_x = iproc.scale(x, new_w, new_h, "Triangle")
   local scaled_y = iproc.scale(y, new_w, new_h, "Triangle")
   return scaled_x, scaled_y
end
|
||||||
|
-- Randomly rotate an image pair by a shared angle.
-- One torch.uniform() draw is compared against `p`; when the draw is below
-- `p`, a value is drawn from torch.uniform(r_min, r_max), multiplied by
-- pi / 360 and handed to iproc.rotate for both images. Otherwise the inputs
-- are returned untouched.
-- NOTE(review): if r_min/r_max are meant to be degrees, the usual
-- degrees-to-radians factor is pi / 180; dividing by 360 halves the applied
-- angle. Left as-is to preserve behavior -- confirm the intent.
-- x, y must agree in size along dimensions 2 and 3 (asserted only when the
-- augmentation is actually applied).
-- Returns the (possibly rotated) pair x, y.
function data_augmentation.pairwise_rotate(x, y, p, r_min, r_max)
   if not (torch.uniform() < p) then
      -- augmentation not selected for this sample
      return x, y
   end
   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
   local angle = torch.uniform(r_min, r_max) / 360.0 * math.pi
   -- the same angle is used for both images so they stay aligned
   local rotated_x = iproc.rotate(x, angle)
   local rotated_y = iproc.rotate(y, angle)
   return rotated_x, rotated_y
end
|
||||||
|
-- Randomly negate (invert) both images of a pair.
-- One torch.uniform() draw is compared against `p`; when the draw is below
-- `p`, both x and y are passed through iproc.negate. Otherwise the inputs
-- are returned untouched.
-- x, y must agree in size along dimensions 2 and 3 (asserted only when the
-- augmentation is actually applied).
-- Returns the (possibly negated) pair x, y.
function data_augmentation.pairwise_negate(x, y, p)
   if torch.uniform() < p then
      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
      -- Fix: y used to be sent to iproc.rotate with an undefined angle `r`
      -- (a copy/paste slip from pairwise_rotate); both sides must be negated.
      x = iproc.negate(x)
      y = iproc.negate(y)
      return x, y
   else
      return x, y
   end
end
|
||||||
|
-- Randomly negate (invert) only the x image of a pair.
-- One torch.uniform() draw is compared against `p`; when the draw is below
-- `p`, x is passed through iproc.negate and y is returned unchanged.
-- Otherwise both inputs are returned untouched.
-- x, y must agree in size along dimensions 2 and 3 (asserted only when the
-- augmentation is actually applied).
-- Returns the pair x, y.
function data_augmentation.pairwise_negate_x(x, y, p)
   if torch.uniform() < p then
      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
      -- iproc.negate takes a single argument; the former second argument was
      -- a read of the undefined global `r` and has been dropped.
      x = iproc.negate(x)
      return x, y
   else
      return x, y
   end
end
|
||||||
function data_augmentation.shift_1px(src)
|
function data_augmentation.shift_1px(src)
|
||||||
-- reducing the even/odd issue in nearest neighbor scaler.
|
-- reducing the even/odd issue in nearest neighbor scaler.
|
||||||
local direction = torch.random(1, 4)
|
local direction = torch.random(1, 4)
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
local gm = {}
|
local gm = {}
|
||||||
gm.Image = require 'graphicsmagick.Image'
|
gm.Image = require 'graphicsmagick.Image'
|
||||||
local image = nil
|
|
||||||
require 'dok'
|
require 'dok'
|
||||||
require 'image'
|
local image = require 'image'
|
||||||
local iproc = {}
|
local iproc = {}
|
||||||
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
||||||
|
|
||||||
|
@ -158,6 +157,47 @@ function iproc.vflip(src)
|
||||||
local im = gm.Image(src, color, "DHW")
|
local im = gm.Image(src, color, "DHW")
|
||||||
return im:flip():toTensor(t, color, "DHW")
|
return im:flip():toTensor(t, color, "DHW")
|
||||||
end
|
end
|
||||||
|
-- Rotate `src` by `theta` into `dst` by building a displacement field and
-- delegating to image.warp.
-- src:   2D or 3D tensor; height/width are read from dims (1,2) for 2D and
--        dims (2,3) for 3D. Any other dimensionality raises via dok.error.
-- dst:   destination tensor, resized to match src before warping.
-- theta: angle fed to math.cos / math.sin (i.e. interpreted in radians).
-- mode:  interpolation mode forwarded to image.warp.
-- Returns whatever image.warp returns.
local function rotate_with_warp(src, dst, theta, mode)
   local rows, cols
   local ndim = src:dim()
   if ndim == 2 then
      rows, cols = src:size(1), src:size(2)
   elseif ndim == 3 then
      rows, cols = src:size(2), src:size(3)
   else
      dok.error('src image must be 2D or 3D', 'image.rotate')
   end

   -- 2x2 rotation matrix for the angle -theta
   local cos_t = math.cos(-theta)
   local sin_t = math.sin(-theta)
   local kernel = torch.Tensor({{cos_t, -sin_t},
                                {sin_t, cos_t}})

   -- flow[1] / flow[2] hold per-pixel row / column coordinates relative to
   -- the rounded image centre
   local flow = torch.Tensor(2, rows, cols)
   flow[1] = torch.ger(torch.linspace(0, 1, rows), torch.ones(cols))
   flow[1]:mul(-(rows - 1)):add(math.floor(rows / 2 + 0.5))
   flow[2] = torch.ger(torch.ones(rows), torch.linspace(0, 1, cols))
   flow[2]:mul(-(cols - 1)):add(math.floor(cols / 2 + 0.5))

   -- subtract the rotated coordinates, leaving the displacement per pixel
   local rotated = torch.mm(kernel, flow:view(2, rows * cols))
   flow:add(-1, rotated)

   dst:resizeAs(src)
   return image.warp(dst, src, flow, mode, true, 'pad')
end
|
||||||
|
-- Rotate an image by `theta` using bicubic warping.
-- The input is first passed through iproc.byte2float; when that reports a
-- conversion took place, the result is converted back with iproc.float2byte
-- so the output type mirrors the input type. The warped values are clamped
-- to [0, 1] before any conversion back.
-- Returns a new tensor; `src` itself is not written to by this function.
function iproc.rotate(src, theta)
   local float_src, was_converted = iproc.byte2float(src)
   local rotated = torch.Tensor():typeAs(float_src):resizeAs(float_src)
   rotate_with_warp(float_src, rotated, theta, 'bicubic')
   -- clamp after the bicubic warp so the result stays inside [0, 1]
   rotated:clamp(0, 1)
   if was_converted then
      return iproc.float2byte(rotated)
   end
   return rotated
end
|
||||||
|
-- Return the negative of an image.
-- For a torch.ByteTensor the result is `-src + 255`; for every other tensor
-- type it is `-src + 1` (presumably values in [0, 1] -- confirm with callers).
-- A new tensor is returned through the arithmetic operators.
function iproc.negate(src)
   -- both candidate offsets are truthy, so the and/or selection is safe
   local white = (src:type() == "torch.ByteTensor") and 255 or 1
   return -src + white
end
|
||||||
|
|
||||||
function iproc.gaussian2d(kernel_size, sigma)
|
function iproc.gaussian2d(kernel_size, sigma)
|
||||||
sigma = sigma or 1
|
sigma = sigma or 1
|
||||||
|
|
|
@ -4,37 +4,13 @@ local gm = {}
|
||||||
gm.Image = require 'graphicsmagick.Image'
|
gm.Image = require 'graphicsmagick.Image'
|
||||||
local pairwise_transform = {}
|
local pairwise_transform = {}
|
||||||
|
|
||||||
-- Crop a random max_size x max_size window out of `y` (and the matching
-- window, divided by scale_y, out of `x`) when `y` is larger than max_size
-- along both dimensions 2 and 3; otherwise return the pair unchanged.
-- x, y:     image pair; y's size is scale_y times x's size per the caller.
-- scale_y:  ratio between y and x coordinates.
-- max_size: window edge in y coordinates; must be a multiple of 4.
-- mod:      optional; when given, the window origin is rounded down to a
--           multiple of it.
-- Up to 4 windows are tried and the first non-flat one is kept; if all are
-- flat, the last one tried is returned.
local function crop_if_large(x, y, scale_y, max_size, mod)
   local tries = 4
   if y:size(2) > max_size and y:size(3) > max_size then
      assert(max_size % 4 == 0, "max_size must be a multiple of 4")
      local rect_x, rect_y
      for _ = 1, tries do
         local yi = torch.random(0, y:size(2) - max_size)
         local xi = torch.random(0, y:size(3) - max_size)
         if mod then
            yi = yi - (yi % mod)
            xi = xi - (xi % mod)
         end
         rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
         rect_x = iproc.crop(x, xi / scale_y, yi / scale_y,
                             xi / scale_y + max_size / scale_y,
                             yi / scale_y + max_size / scale_y)
         -- ignore simple background
         -- Fix: the former test `std() >= 0` is always true, so the retry
         -- loop never retried; a completely flat crop has std == 0.
         if rect_y:float():std() > 0 then
            break
         end
      end
      return rect_x, rect_y
   else
      return x, y
   end
end
|
|
||||||
function pairwise_transform.user(x, y, size, offset, n, options)
|
function pairwise_transform.user(x, y, size, offset, n, options)
|
||||||
assert(x:size(1) == y:size(1))
|
assert(x:size(1) == y:size(1))
|
||||||
|
|
||||||
local scale_y = y:size(2) / x:size(2)
|
local scale_y = y:size(2) / x:size(2)
|
||||||
assert(x:size(3) == y:size(3) / scale_y)
|
assert(x:size(3) == y:size(3) / scale_y)
|
||||||
|
|
||||||
x, y = crop_if_large(x, y, scale_y, options.max_size, scale_y)
|
x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
|
||||||
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
|
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
|
||||||
local batch = {}
|
local batch = {}
|
||||||
local lowres_y = pairwise_utils.low_resolution(y)
|
local lowres_y = pairwise_utils.low_resolution(y)
|
||||||
|
|
|
@ -36,6 +36,30 @@ function pairwise_transform_utils.crop_if_large(src, max_size, mod)
|
||||||
return src
|
return src
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
-- Crop a random max_size x max_size window out of `y` (and the matching
-- window, divided by scale_y, out of `x`) when `y` is larger than max_size
-- along both dimensions 2 and 3; otherwise return the pair unchanged.
-- x, y:     image pair; x is addressed in y coordinates divided by scale_y.
-- scale_y:  ratio between y and x coordinates.
-- max_size: window edge in y coordinates; must be a multiple of 4.
-- mod:      optional; when given, the window origin is rounded down to a
--           multiple of it.
-- Up to 4 windows are tried and the first non-flat one is kept; if all are
-- flat, the last one tried is returned.
function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
   local tries = 4
   if y:size(2) > max_size and y:size(3) > max_size then
      assert(max_size % 4 == 0, "max_size must be a multiple of 4")
      local rect_x, rect_y
      for _ = 1, tries do
         local yi = torch.random(0, y:size(2) - max_size)
         local xi = torch.random(0, y:size(3) - max_size)
         if mod then
            yi = yi - (yi % mod)
            xi = xi - (xi % mod)
         end
         rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
         rect_x = iproc.crop(x, xi / scale_y, yi / scale_y,
                             xi / scale_y + max_size / scale_y,
                             yi / scale_y + max_size / scale_y)
         -- ignore simple background
         -- Fix: the former test `std() >= 0` is always true, so the retry
         -- loop never retried; a completely flat crop has std == 0.
         if rect_y:float():std() > 0 then
            break
         end
      end
      return rect_x, rect_y
   else
      return x, y
   end
end
|
||||||
function pairwise_transform_utils.preprocess(src, crop_size, options)
|
function pairwise_transform_utils.preprocess(src, crop_size, options)
|
||||||
local dest = src
|
local dest = src
|
||||||
local box_only = false
|
local box_only = false
|
||||||
|
@ -65,6 +89,33 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
|
||||||
end
|
end
|
||||||
return dest
|
return dest
|
||||||
end
|
end
|
||||||
|
-- Preprocess a user-supplied (x, y) training pair.
-- Steps, in order:
--   1. crop_if_large_pair with options.max_size, using scale_y as the
--      origin alignment;
--   2. pairwise rotate, scale, negate and negate-x augmentations, each
--      controlled by the matching options.random_pairwise_* fields;
--   3. iproc.crop_mod4 on both images;
--   4. when options.pairwise_y_binary is truthy, binarize y in place:
--      values below 128 become 0, every remaining non-zero value becomes 255.
-- `size` bounds the minimum scale factor so that the scaled x does not drop
-- below roughly `size` on its shortest side.
-- Returns the processed pair x, y.
function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
   local aug = data_augmentation

   x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y,
                                                      options.max_size, scale_y)
   x, y = aug.pairwise_rotate(x, y,
                              options.random_pairwise_rotate_rate,
                              options.random_pairwise_rotate_min,
                              options.random_pairwise_rotate_max)

   -- the lower scale bound depends on x's size after the rotation step
   local shortest_side = math.min(x:size(2), x:size(3))
   local lowest_allowed = size / (1 + shortest_side)
   local scale_min = math.max(options.random_pairwise_scale_min, lowest_allowed)
   local scale_max = math.max(scale_min, options.random_pairwise_scale_max)
   x, y = aug.pairwise_scale(x, y,
                             options.random_pairwise_scale_rate,
                             scale_min, scale_max)

   x, y = aug.pairwise_negate(x, y, options.random_pairwise_negate_rate)
   x, y = aug.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)

   x = iproc.crop_mod4(x)
   y = iproc.crop_mod4(y)

   if options.pairwise_y_binary then
      -- order matters: zero the dark values first, then lift what is left
      local dark = torch.lt(y, 128)
      y[dark] = 0
      local lit = torch.gt(y, 0)
      y[lit] = 255
   end

   return x, y
end
|
||||||
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
|
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
|
||||||
assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
|
assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
|
||||||
assert("crop_size % scale == 0", size % scale == 0)
|
assert("crop_size % scale == 0", size % scale == 0)
|
||||||
|
|
|
@ -37,6 +37,15 @@ cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0
|
||||||
cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
|
cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
|
||||||
cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
|
cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
|
||||||
cmd:option("-random_blur_sigma_max", 0.75, 'max sigma for random gaussian blur')
|
cmd:option("-random_blur_sigma_max", 0.75, 'max sigma for random gaussian blur')
|
||||||
|
-- Pairwise (x and y transformed together) augmentation options for the
-- "user" method. Help-text fixes only: the rotate rate no longer claims to
-- be a resize, and "nagate" is spelled "negate". Option names and defaults
-- are unchanged.
cmd:option("-random_pairwise_scale_rate", 0.0, 'data augmentation using pairwise resize for user method')
cmd:option("-random_pairwise_scale_min", 0.85, 'min scale factor for random pairwise scale')
cmd:option("-random_pairwise_scale_max", 1.176, 'max scale factor for random pairwise scale')
cmd:option("-random_pairwise_rotate_rate", 0.0, 'data augmentation using pairwise rotate for user method')
cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairwise rotate')
cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using negate image for user method')
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using negate image only x side for user method')
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
|
||||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||||
cmd:option("-crop_size", 48, 'crop size')
|
cmd:option("-crop_size", 48, 'crop size')
|
||||||
|
@ -81,6 +90,7 @@ end
|
||||||
to_bool(settings, "plot")
|
to_bool(settings, "plot")
|
||||||
to_bool(settings, "save_history")
|
to_bool(settings, "save_history")
|
||||||
to_bool(settings, "use_transparent_png")
|
to_bool(settings, "use_transparent_png")
|
||||||
|
to_bool(settings, "pairwise_y_binary")
|
||||||
|
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
require 'gnuplot'
|
require 'gnuplot'
|
||||||
|
|
16
train.lua
16
train.lua
|
@ -179,6 +179,15 @@ local function transform_pool_init(has_resize, offset)
|
||||||
max_size = settings.max_size,
|
max_size = settings.max_size,
|
||||||
active_cropping_rate = active_cropping_rate,
|
active_cropping_rate = active_cropping_rate,
|
||||||
active_cropping_tries = active_cropping_tries,
|
active_cropping_tries = active_cropping_tries,
|
||||||
|
random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
|
||||||
|
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
|
||||||
|
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
|
||||||
|
random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
|
||||||
|
random_pairwise_scale_min = settings.random_pairwise_scale_min,
|
||||||
|
random_pairwise_scale_max = settings.random_pairwise_scale_max,
|
||||||
|
random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
|
||||||
|
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
|
||||||
|
pairwise_y_binary = settings.pairwise_y_binary,
|
||||||
rgb = (settings.color == "rgb")}, meta)
|
rgb = (settings.color == "rgb")}, meta)
|
||||||
return pairwise_transform.user(x, y,
|
return pairwise_transform.user(x, y,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
|
@ -393,6 +402,13 @@ local function train()
|
||||||
else
|
else
|
||||||
model = srcnn.create(settings.model, settings.backend, settings.color)
|
model = srcnn.create(settings.model, settings.backend, settings.color)
|
||||||
end
|
end
|
||||||
|
if model.w2nn_input_size then
|
||||||
|
if settings.crop_size ~= model.w2nn_input_size then
|
||||||
|
io.stderr:write(string.format("warning: crop_size is replaced with %d\n",
|
||||||
|
model.w2nn_input_size))
|
||||||
|
settings.crop_size = model.w2nn_input_size
|
||||||
|
end
|
||||||
|
end
|
||||||
dir.makepath(settings.model_dir)
|
dir.makepath(settings.model_dir)
|
||||||
|
|
||||||
local offset = reconstruct.offset_size(model)
|
local offset = reconstruct.offset_size(model)
|
||||||
|
|
Loading…
Reference in a new issue