merge random erasing
This commit is contained in:
parent
9f4626df80
commit
93cd40a53c
|
@ -13,6 +13,34 @@ local function pcacov(x)
|
||||||
local ce, cv = torch.symeig(c, 'V')
|
local ce, cv = torch.symeig(c, 'V')
|
||||||
return ce, cv
|
return ce, cv
|
||||||
end
|
end
|
||||||
|
function data_augmentation.erase(src, p, n, rect_min, rect_max)
|
||||||
|
if torch.uniform() < p then
|
||||||
|
local src, conversion = iproc.byte2float(src)
|
||||||
|
src = src:contiguous():clone()
|
||||||
|
local ch = src:size(1)
|
||||||
|
local height = src:size(2)
|
||||||
|
local width = src:size(3)
|
||||||
|
for i = 1, n do
|
||||||
|
local r = torch.Tensor(4):uniform():cmul(torch.Tensor({height-1, width-1, rect_max - rect_min, rect_max - rect_min})):int()
|
||||||
|
local rect_y1 = r[1] + 1
|
||||||
|
local rect_x1 = r[2] + 1
|
||||||
|
local rect_h = r[3] + rect_min
|
||||||
|
local rect_w = r[4] + rect_min
|
||||||
|
local rect_x2 = math.min(rect_x1 + rect_w, width)
|
||||||
|
local rect_y2 = math.min(rect_y1 + rect_h, height)
|
||||||
|
local sub_rect = src:sub(1, ch, rect_y1, rect_y2, rect_x1, rect_x2)
|
||||||
|
for i = 1, ch do
|
||||||
|
sub_rect[i]:fill(src[i][rect_y1][rect_x1])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
if conversion then
|
||||||
|
src = iproc.float2byte(src)
|
||||||
|
end
|
||||||
|
return src
|
||||||
|
else
|
||||||
|
return src
|
||||||
|
end
|
||||||
|
end
|
||||||
function data_augmentation.color_noise(src, p, factor)
|
function data_augmentation.color_noise(src, p, factor)
|
||||||
factor = factor or 0.1
|
factor = factor or 0.1
|
||||||
if torch.uniform() < p then
|
if torch.uniform() < p then
|
||||||
|
|
|
@ -105,7 +105,11 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
|
||||||
scale_max)
|
scale_max)
|
||||||
x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
|
x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
|
||||||
x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
|
x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
|
||||||
|
x = data_augmentation.erase(x,
|
||||||
|
options.random_erasing_rate,
|
||||||
|
options.random_erasing_n,
|
||||||
|
options.random_erasing_rect_min,
|
||||||
|
options.random_erasing_rect_max)
|
||||||
x = iproc.crop_mod4(x)
|
x = iproc.crop_mod4(x)
|
||||||
y = iproc.crop_mod4(y)
|
y = iproc.crop_mod4(y)
|
||||||
return x, y
|
return x, y
|
||||||
|
|
|
@ -44,6 +44,10 @@ cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairw
|
||||||
cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
|
cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
|
||||||
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
|
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
|
||||||
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
|
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
|
||||||
|
cmd:option("-random_erasing_rate", 0.0, 'data augmentation using random erasing for user method')
|
||||||
|
cmd:option("-random_erasing_n", 1, 'number of erasing')
|
||||||
|
cmd:option("-random_erasing_rect_min", 8, 'rect min size')
|
||||||
|
cmd:option("-random_erasing_rect_max", 32, 'rect max size')
|
||||||
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
|
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
|
||||||
cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
|
cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
|
||||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||||
|
|
15
train.lua
15
train.lua
|
@ -215,6 +215,17 @@ local function transform_pool_init(has_resize, offset)
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
n, conf)
|
n, conf)
|
||||||
elseif settings.method == "user" then
|
elseif settings.method == "user" then
|
||||||
|
local random_erasing_rate = 0
|
||||||
|
local random_erasing_n = 0
|
||||||
|
local random_erasing_rect_min = 0
|
||||||
|
local random_erasing_rect_max = 0
|
||||||
|
if is_validation then
|
||||||
|
else
|
||||||
|
random_erasing_rate = settings.random_erasing_rate
|
||||||
|
random_erasing_n = settings.random_erasing_n
|
||||||
|
random_erasing_rect_min = settings.random_erasing_rect_min
|
||||||
|
random_erasing_rect_max = settings.random_erasing_rect_max
|
||||||
|
end
|
||||||
local conf = tablex.update({
|
local conf = tablex.update({
|
||||||
gcn = settings.gcn,
|
gcn = settings.gcn,
|
||||||
max_size = settings.max_size,
|
max_size = settings.max_size,
|
||||||
|
@ -230,6 +241,10 @@ local function transform_pool_init(has_resize, offset)
|
||||||
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
|
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
|
||||||
pairwise_y_binary = settings.pairwise_y_binary,
|
pairwise_y_binary = settings.pairwise_y_binary,
|
||||||
pairwise_flip = settings.pairwise_flip,
|
pairwise_flip = settings.pairwise_flip,
|
||||||
|
random_erasing_rate = random_erasing_rate,
|
||||||
|
random_erasing_n = random_erasing_n,
|
||||||
|
random_erasing_rect_min = random_erasing_rect_min,
|
||||||
|
random_erasing_rect_max = random_erasing_rect_max,
|
||||||
rgb = (settings.color == "rgb")}, meta)
|
rgb = (settings.color == "rgb")}, meta)
|
||||||
return pairwise_transform.user(x, y,
|
return pairwise_transform.user(x, y,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
|
|
Loading…
Reference in a new issue