diff --git a/lib/data_augmentation.lua b/lib/data_augmentation.lua
index ce73bfe..3c76b5f 100644
--- a/lib/data_augmentation.lua
+++ b/lib/data_augmentation.lua
@@ -96,6 +96,57 @@ function data_augmentation.blur(src, p, size, sigma_min, sigma_max)
       return src
    end
 end
+-- With probability p, rescales x and y together by one shared factor drawn
+-- from [scale_min, scale_max] using the "Triangle" filter. x and y must have
+-- equal sizes along dimensions 2 and 3. Returns x, y.
+function data_augmentation.pairwise_scale(x, y, p, scale_min, scale_max)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      local scale = torch.uniform(scale_min, scale_max)
+      local h = math.floor(x:size(2) * scale)
+      local w = math.floor(x:size(3) * scale)
+      x = iproc.scale(x, w, h, "Triangle")
+      y = iproc.scale(y, w, h, "Triangle")
+      return x, y
+   else
+      return x, y
+   end
+end
+-- With probability p, rotates x and y together by one shared random angle.
+-- NOTE(review): `/ 360.0 * math.pi` is half of a degree -> radian conversion
+-- (deg * math.pi / 180); behaviour kept as-is, confirm the intended unit.
+function data_augmentation.pairwise_rotate(x, y, p, r_min, r_max)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      local r = torch.uniform(r_min, r_max) / 360.0 * math.pi
+      x = iproc.rotate(x, r)
+      y = iproc.rotate(y, r)
+      return x, y
+   else
+      return x, y
+   end
+end
+-- With probability p, negates both x and y via iproc.negate. Returns x, y.
+function data_augmentation.pairwise_negate(x, y, p)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      x = iproc.negate(x)
+      y = iproc.negate(y)
+      return x, y
+   else
+      return x, y
+   end
+end
+-- With probability p, negates only x; y is always returned untouched.
+function data_augmentation.pairwise_negate_x(x, y, p)
+   if torch.uniform() < p then
+      assert(x:size(2) == y:size(2) and x:size(3) == y:size(3))
+      x = iproc.negate(x)
+      return x, y
+   else
+      return x, y
+   end
+end
 function data_augmentation.shift_1px(src)
    -- reducing the even/odd issue in nearest neighbor scaler.
    local direction = torch.random(1, 4)
diff --git a/lib/iproc.lua b/lib/iproc.lua
index 8fddfb6..9de3eea 100644
--- a/lib/iproc.lua
+++ b/lib/iproc.lua
@@ -1,8 +1,7 @@
 local gm = {}
 gm.Image = require 'graphicsmagick.Image'
-local image = nil
 require 'dok'
-require 'image'
+local image = require 'image'
 local iproc = {}
 local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
 
@@ -158,6 +157,52 @@ function iproc.vflip(src)
    local im = gm.Image(src, color, "DHW")
    return im:flip():toTensor(t, color, "DHW")
 end
+-- Rotates a 2D or 3D src by theta (radians, fed to math.cos/math.sin) into
+-- dst by building a flow field and calling image.warp with clamp mode 'pad'.
+local function rotate_with_warp(src, dst, theta, mode)
+   local height
+   local width
+   if src:dim() == 2 then
+      height = src:size(1)
+      width = src:size(2)
+   elseif src:dim() == 3 then
+      height = src:size(2)
+      width = src:size(3)
+   else
+      dok.error('src image must be 2D or 3D', 'image.rotate')
+   end
+   local flow = torch.Tensor(2, height, width)
+   local kernel = torch.Tensor({{math.cos(-theta), -math.sin(-theta)},
+                                {math.sin(-theta), math.cos(-theta)}})
+   flow[1] = torch.ger(torch.linspace(0, 1, height), torch.ones(width))
+   flow[1]:mul(-(height -1)):add(math.floor(height / 2 + 0.5))
+   flow[2] = torch.ger(torch.ones(height), torch.linspace(0, 1, width))
+   flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
+   flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
+   dst:resizeAs(src)
+   return image.warp(dst, src, flow, mode, true, 'pad')
+end
+-- Rotates src by theta with 'bicubic' warping, clamps the result to [0, 1]
+-- and converts it back with float2byte when byte2float reported a conversion.
+function iproc.rotate(src, theta)
+   local conversion
+   src, conversion = iproc.byte2float(src)
+   local dest = torch.Tensor():typeAs(src):resizeAs(src)
+   rotate_with_warp(src, dest, theta, 'bicubic')
+   dest:clamp(0, 1)
+   if conversion then
+      dest = iproc.float2byte(dest)
+   end
+   return dest
+end
+-- Returns -src + 255 for a torch.ByteTensor and -src + 1 for any other type.
+function iproc.negate(src)
+   if src:type() == "torch.ByteTensor" then
+      return -src + 255
+   else
+      return -src + 1
+   end
+end
 function iproc.gaussian2d(kernel_size, sigma)
    sigma = sigma or 1
 
diff --git a/lib/pairwise_transform_user.lua b/lib/pairwise_transform_user.lua
index fb3ff1f..8493a41 100644
--- a/lib/pairwise_transform_user.lua
+++ b/lib/pairwise_transform_user.lua
@@ -4,37 +4,13 @@ local gm = {}
 gm.Image = require 'graphicsmagick.Image'
 local pairwise_transform = {}
 
-local function crop_if_large(x, y, scale_y, max_size, mod)
-   local tries = 4
-   if y:size(2) > max_size and y:size(3) > max_size then
-      assert(max_size % 4 == 0)
-      local rect_x, rect_y
-      for i = 1, tries do
-         local yi = torch.random(0, y:size(2) - max_size)
-         local xi = torch.random(0, y:size(3) - max_size)
-         if mod then
-            yi = yi - (yi % mod)
-            xi = xi - (xi % mod)
-         end
-         rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
-         rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y)
-         -- ignore simple background
-         if rect_y:float():std() >= 0 then
-            break
-         end
-      end
-      return rect_x, rect_y
-   else
-      return x, y
-   end
-end
 function pairwise_transform.user(x, y, size, offset, n, options)
    assert(x:size(1) == y:size(1))
 
    local scale_y = y:size(2) / x:size(2)
    assert(x:size(3) == y:size(3) / scale_y)
 
-   x, y = crop_if_large(x, y, scale_y, options.max_size, scale_y)
+   x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
    assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
    local batch = {}
    local lowres_y = pairwise_utils.low_resolution(y)
diff --git a/lib/pairwise_transform_utils.lua b/lib/pairwise_transform_utils.lua
index 57508b1..6082d2e 100644
--- a/lib/pairwise_transform_utils.lua
+++ b/lib/pairwise_transform_utils.lua
@@ -36,6 +36,34 @@ function pairwise_transform_utils.crop_if_large(src, max_size, mod)
       return src
    end
 end
+-- When y exceeds max_size along both dimensions 2 and 3, crops a random
+-- max_size window from y and the matching window (coordinates divided by
+-- scale_y) from x; otherwise returns x, y unchanged.
+function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
+   local tries = 4
+   if y:size(2) > max_size and y:size(3) > max_size then
+      assert(max_size % 4 == 0)
+      local rect_x, rect_y
+      for i = 1, tries do
+         local yi = torch.random(0, y:size(2) - max_size)
+         local xi = torch.random(0, y:size(3) - max_size)
+         if mod then
+            yi = yi - (yi % mod)
+            xi = xi - (xi % mod)
+         end
+         rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
+         rect_x = iproc.crop(x, xi / scale_y, yi / scale_y, xi / scale_y + max_size / scale_y, yi / scale_y + max_size / scale_y)
+         -- ignore simple background
+         -- NOTE(review): `>= 0` is always true for std(), so the first try is always kept.
+         if rect_y:float():std() >= 0 then
+            break
+         end
+      end
+      return rect_x, rect_y
+   else
+      return x, y
+   end
+end
 function pairwise_transform_utils.preprocess(src, crop_size, options)
    local dest = src
    local box_only = false
@@ -65,6 +93,37 @@ function pairwise_transform_utils.preprocess(src, crop_size, options)
    end
    return dest
 end
+-- Preprocessing for the user method: pairwise crop, then the random pairwise
+-- rotate / scale / negate / negate_x augmentations, crop_mod4 on both images
+-- and, when options.pairwise_y_binary is set, thresholding of y. Returns x, y.
+function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
+
+   x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
+   x, y = data_augmentation.pairwise_rotate(x, y,
+                                            options.random_pairwise_rotate_rate,
+                                            options.random_pairwise_rotate_min,
+                                            options.random_pairwise_rotate_max)
+
+   local scale_min = math.max(options.random_pairwise_scale_min, size / (1 + math.min(x:size(2), x:size(3))))
+   local scale_max = math.max(scale_min, options.random_pairwise_scale_max)
+   x, y = data_augmentation.pairwise_scale(x, y,
+                                           options.random_pairwise_scale_rate,
+                                           scale_min,
+                                           scale_max)
+   x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
+   x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
+
+   x = iproc.crop_mod4(x)
+   y = iproc.crop_mod4(y)
+
+   if options.pairwise_y_binary then
+      -- presumably y holds 0..255 values here -- TODO confirm against callers
+      y[torch.lt(y, 128)] = 0
+      y[torch.gt(y, 0)] = 255
+   end
+
+   return x, y
+end
 function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
    assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
    assert("crop_size % scale == 0", size % scale == 0)
diff --git a/lib/settings.lua b/lib/settings.lua
index c150410..accec5e 100644
--- a/lib/settings.lua
+++ b/lib/settings.lua
@@ -37,6 +37,15 @@ cmd:option("-random_blur_rate", 0.0, 'data augmentation using gaussian blur (0.0
 cmd:option("-random_blur_size", "3,5", 'filter size for random gaussian blur (comma separated)')
 cmd:option("-random_blur_sigma_min", 0.5, 'min sigma for random gaussian blur')
 cmd:option("-random_blur_sigma_max", 0.75, 'max sigma for random gaussian blur')
+cmd:option("-random_pairwise_scale_rate", 0.0, 'data augmentation using pairwise resize for user method')
+cmd:option("-random_pairwise_scale_min", 0.85, 'min scale factor for random pairwise scale')
+cmd:option("-random_pairwise_scale_max", 1.176, 'max scale factor for random pairwise scale')
+cmd:option("-random_pairwise_rotate_rate", 0.0, 'data augmentation using pairwise rotate for user method')
+cmd:option("-random_pairwise_rotate_min", -6, 'min rotate angle for random pairwise rotate')
+cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwise rotate')
+cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using negate image for user method')
+cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using negate image only x side for user method')
+cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
 cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-crop_size", 48, 'crop size')
@@ -81,6 +90,7 @@ end
 to_bool(settings, "plot")
 to_bool(settings, "save_history")
 to_bool(settings, "use_transparent_png")
+to_bool(settings, "pairwise_y_binary")
 
 if settings.plot then
    require 'gnuplot'
diff --git a/train.lua b/train.lua
index c0ea923..6c4821a 100644
--- a/train.lua
+++ b/train.lua
@@ -179,6 +179,15 @@ local function transform_pool_init(has_resize, offset)
                max_size = settings.max_size,
                active_cropping_rate = active_cropping_rate,
                active_cropping_tries = active_cropping_tries,
+               random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
+               random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
+               random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
+               random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
+               random_pairwise_scale_min = settings.random_pairwise_scale_min,
+               random_pairwise_scale_max = settings.random_pairwise_scale_max,
+               random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
+               random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
+               pairwise_y_binary = settings.pairwise_y_binary,
                rgb = (settings.color == "rgb")}, meta)
          return pairwise_transform.user(x, y,
                                         settings.crop_size, offset,
@@ -393,6 +402,14 @@ local function train()
    else
       model = srcnn.create(settings.model, settings.backend, settings.color)
    end
+   -- a model that declares w2nn_input_size overrides the crop_size setting
+   if model.w2nn_input_size then
+      if settings.crop_size ~= model.w2nn_input_size then
+         io.stderr:write(string.format("warning: crop_size is replaced with %d\n",
+                                       model.w2nn_input_size))
+         settings.crop_size = model.w2nn_input_size
+      end
+   end
    dir.makepath(settings.model_dir)
 
    local offset = reconstruct.offset_size(model)