diff --git a/lib/iproc.lua b/lib/iproc.lua index 5144222..b4e6e17 100644 --- a/lib/iproc.lua +++ b/lib/iproc.lua @@ -178,7 +178,7 @@ local function rotate_with_warp(src, dst, theta, mode) flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5)) flow:add(-1, torch.mm(kernel, flow:view(2, height * width))) dst:resizeAs(src) - return image.warp(dst, src, flow, mode, true, 'pad') + return image.warp(dst, src, flow, mode, true, 'clamp') end function iproc.rotate(src, theta) local conversion @@ -212,6 +212,16 @@ function iproc.gaussian2d(kernel_size, sigma) kernel:div(kernel:sum()) return kernel end +function iproc.rgb2y(src) + local conversion + src, conversion = iproc.byte2float(src) + local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero() + dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3]) + if conversion then + dest = iproc.float2byte(dest) + end + return dest +end local function test_conversion() local a = torch.linspace(0, 255, 256):float():div(255.0) diff --git a/lib/pairwise_transform_user.lua b/lib/pairwise_transform_user.lua index 8493a41..d6a4f82 100644 --- a/lib/pairwise_transform_user.lua +++ b/lib/pairwise_transform_user.lua @@ -13,8 +13,18 @@ function pairwise_transform.user(x, y, size, offset, n, options) x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options) assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y) local batch = {} - local lowres_y = pairwise_utils.low_resolution(y) - local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y) + local lowres_y = nil + local xs ={x} + local ys = {y} + local ls = {} + + if options.active_cropping_rate > 0 then + lowres_y = pairwise_utils.low_resolution(y) + end + if options.pairwise_flip then + xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y) + end + assert(#xs == #ys) for i = 1, n do local t = (i % #xs) + 1 local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y, @@ -24,8 +34,8 @@ function pairwise_transform.user(x, y, size, offset, n, options) yc = iproc.byte2float(yc) if options.rgb then else - yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) - xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + yc = iproc.rgb2y(yc) + xc = iproc.rgb2y(xc) end table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) end diff --git a/lib/pairwise_transform_utils.lua b/lib/pairwise_transform_utils.lua index 6082d2e..5938a86 100644 --- a/lib/pairwise_transform_utils.lua +++ b/lib/pairwise_transform_utils.lua @@ -164,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise) for j = 1, 2 do -- TTA - local xi, yi, ri + local xi, yi, ri, ni if j == 1 then xi = x ni = x_noise @@ -176,7 +176,9 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise) ni = x_noise:transpose(2, 3):contiguous() end yi = y:transpose(2, 3):contiguous() - ri = lowres_y:transpose(2, 3):contiguous() + if lowres_y then + ri = lowres_y:transpose(2, 3):contiguous() + end end local xv = iproc.vflip(xi) local nv @@ -184,34 +186,45 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise) nv = iproc.vflip(ni) end local yv = iproc.vflip(yi) - local rv = iproc.vflip(ri) + local rv + if ri then + rv = iproc.vflip(ri) + end table.insert(xs, xi) if ni then table.insert(ns, ni) end table.insert(ys, yi) - table.insert(ls, ri) + if ri then + table.insert(ls, ri) + end table.insert(xs, xv) if nv then table.insert(ns, nv) end table.insert(ys, yv) - table.insert(ls, rv) + if rv then + table.insert(ls, rv) + end table.insert(xs, iproc.hflip(xi)) if ni then table.insert(ns, iproc.hflip(ni)) end table.insert(ys, iproc.hflip(yi)) - table.insert(ls, iproc.hflip(ri)) + if ri then + table.insert(ls, iproc.hflip(ri)) + end table.insert(xs, iproc.hflip(xv)) if nv then table.insert(ns, iproc.hflip(nv)) end table.insert(ys, iproc.hflip(yv)) - table.insert(ls, iproc.hflip(rv)) + if rv then + table.insert(ls, iproc.hflip(rv)) + end end return xs, ys, ls, ns end diff --git a/lib/settings.lua b/lib/settings.lua index accec5e..41a4085 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -46,6 +46,7 @@ cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwi cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method') cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method') cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)') +cmd:option("-pairwise_flip", 1, 'use flip(0|1)') cmd:option("-scale", 2.0, 'scale factor (2)') cmd:option("-learning_rate", 0.00025, 'learning rate for adam') cmd:option("-crop_size", 48, 'crop size') @@ -91,6 +92,7 @@ to_bool(settings, "plot") to_bool(settings, "save_history") to_bool(settings, "use_transparent_png") to_bool(settings, "pairwise_y_binary") +to_bool(settings, "pairwise_flip") if settings.plot then require 'gnuplot' diff --git a/lib/srcnn.lua b/lib/srcnn.lua index 2ac2a05..e7a49c8 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -466,9 +466,6 @@ function srcnn.fcn_v1(backend, ch) model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1, true)) model:add(nn.Dropout(0.5, false, true)) - model:add(SpatialConvolution(backend, 256, 256, 1, 1, 1, 1, 0, 0)) - model:add(nn.LeakyReLU(0.1, true)) - model:add(nn.Dropout(0.5, false, true)) model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0)) model:add(nn.LeakyReLU(0.1, true)) diff --git a/train.lua b/train.lua index 6c4821a..43777b9 100644 --- a/train.lua +++ b/train.lua @@ -175,19 +175,36 @@ local function transform_pool_init(has_resize, offset) settings.crop_size, offset, n, conf) elseif settings.method == "user" then + if is_validation == nil then is_validation = false end + local rotate_rate = nil + local scale_rate = nil + local negate_rate = nil + local negate_x_rate = nil + if is_validation then + rotate_rate = 0 + scale_rate = 0 + negate_rate = 0 + negate_x_rate = 0 + else + rotate_rate = settings.random_pairwise_rotate_rate + scale_rate = settings.random_pairwise_scale_rate + negate_rate = settings.random_pairwise_negate_rate + negate_x_rate = settings.random_pairwise_negate_x_rate + end local conf = tablex.update({ max_size = settings.max_size, active_cropping_rate = active_cropping_rate, active_cropping_tries = active_cropping_tries, - random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate, + random_pairwise_rotate_rate = rotate_rate, random_pairwise_rotate_min = settings.random_pairwise_rotate_min, random_pairwise_rotate_max = settings.random_pairwise_rotate_max, - random_pairwise_scale_rate = settings.random_pairwise_scale_rate, + random_pairwise_scale_rate = scale_rate, random_pairwise_scale_min = settings.random_pairwise_scale_min, random_pairwise_scale_max = settings.random_pairwise_scale_max, - random_pairwise_negate_rate = settings.random_pairwise_negate_rate, - random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate, + random_pairwise_negate_rate = negate_rate, + random_pairwise_negate_x_rate = negate_x_rate, pairwise_y_binary = settings.pairwise_y_binary, + pairwise_flip = settings.pairwise_flip, rgb = (settings.color == "rgb")}, meta) return pairwise_transform.user(x, y, settings.crop_size, offset,