diff --git a/lib/data_augmentation.lua b/lib/data_augmentation.lua index 7bdb93e..65a50c8 100644 --- a/lib/data_augmentation.lua +++ b/lib/data_augmentation.lua @@ -13,6 +13,34 @@ local function pcacov(x) local ce, cv = torch.symeig(c, 'V') return ce, cv end +function data_augmentation.erase(src, p, n, rect_min, rect_max) + if torch.uniform() < p then + local src, conversion = iproc.byte2float(src) + src = src:contiguous():clone() + local ch = src:size(1) + local height = src:size(2) + local width = src:size(3) + for i = 1, n do + local r = torch.Tensor(4):uniform():cmul(torch.Tensor({height-1, width-1, rect_max - rect_min, rect_max - rect_min})):int() + local rect_y1 = r[1] + 1 + local rect_x1 = r[2] + 1 + local rect_h = r[3] + rect_min + local rect_w = r[4] + rect_min + local rect_x2 = math.min(rect_x1 + rect_w, width) + local rect_y2 = math.min(rect_y1 + rect_h, height) + local sub_rect = src:sub(1, ch, rect_y1, rect_y2, rect_x1, rect_x2) + for i = 1, ch do + sub_rect[i]:fill(src[i][rect_y1][rect_x1]) + end + end + if conversion then + src = iproc.float2byte(src) + end + return src + else + return src + end +end function data_augmentation.color_noise(src, p, factor) factor = factor or 0.1 if torch.uniform() < p then diff --git a/lib/pairwise_transform_utils.lua b/lib/pairwise_transform_utils.lua index 894efb2..199ab79 100644 --- a/lib/pairwise_transform_utils.lua +++ b/lib/pairwise_transform_utils.lua @@ -105,7 +105,11 @@ function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options) scale_max) x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate) x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate) - + x = data_augmentation.erase(x, + options.random_erasing_rate, + options.random_erasing_n, + options.random_erasing_rect_min, + options.random_erasing_rect_max) x = iproc.crop_mod4(x) y = iproc.crop_mod4(y) return x, y diff --git a/lib/settings.lua b/lib/settings.lua index 846a0f3..63e4699 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -44,6 +44,10 @@ cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairw cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate') cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method') cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method') +cmd:option("-random_erasing_rate", 0.0, 'data augmentation using random erasing for user method') +cmd:option("-random_erasing_n", 1, 'number of erasing') +cmd:option("-random_erasing_rect_min", 8, 'rect min size') +cmd:option("-random_erasing_rect_max", 32, 'rect max size') cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)') cmd:option("-pairwise_flip", 1, 'use flip(0|1)') cmd:option("-scale", 2.0, 'scale factor (2)') diff --git a/train.lua b/train.lua index 21cf08b..b2800a3 100644 --- a/train.lua +++ b/train.lua @@ -215,6 +215,17 @@ local function transform_pool_init(has_resize, offset) settings.crop_size, offset, n, conf) elseif settings.method == "user" then + local random_erasing_rate = 0 + local random_erasing_n = 0 + local random_erasing_rect_min = 0 + local random_erasing_rect_max = 0 + if is_validation then + else + random_erasing_rate = settings.random_erasing_rate + random_erasing_n = settings.random_erasing_n + random_erasing_rect_min = settings.random_erasing_rect_min + random_erasing_rect_max = settings.random_erasing_rect_max + end local conf = tablex.update({ gcn = settings.gcn, max_size = settings.max_size, @@ -230,6 +241,10 @@ local function transform_pool_init(has_resize, offset) random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate, pairwise_y_binary = settings.pairwise_y_binary, pairwise_flip = settings.pairwise_flip, + random_erasing_rate = random_erasing_rate, + random_erasing_n = random_erasing_n, + random_erasing_rect_min = random_erasing_rect_min, + random_erasing_rect_max = random_erasing_rect_max, rgb = (settings.color == "rgb")}, meta) return pairwise_transform.user(x, y, settings.crop_size, offset,