perfomance tuning
This commit is contained in:
parent
6220afc31a
commit
bdaca16c67
|
@ -178,7 +178,7 @@ local function rotate_with_warp(src, dst, theta, mode)
|
|||
flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
|
||||
flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
|
||||
dst:resizeAs(src)
|
||||
return image.warp(dst, src, flow, mode, true, 'pad')
|
||||
return image.warp(dst, src, flow, mode, true, 'clamp')
|
||||
end
|
||||
function iproc.rotate(src, theta)
|
||||
local conversion
|
||||
|
@ -212,6 +212,16 @@ function iproc.gaussian2d(kernel_size, sigma)
|
|||
kernel:div(kernel:sum())
|
||||
return kernel
|
||||
end
|
||||
function iproc.rgb2y(src)
|
||||
local conversion
|
||||
src, conversion = iproc.byte2float(src)
|
||||
local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
|
||||
dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
return dest
|
||||
end
|
||||
|
||||
local function test_conversion()
|
||||
local a = torch.linspace(0, 255, 256):float():div(255.0)
|
||||
|
|
|
@ -13,8 +13,18 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
|||
x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
|
||||
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
|
||||
local batch = {}
|
||||
local lowres_y = pairwise_utils.low_resolution(y)
|
||||
local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
|
||||
local lowres_y = nil
|
||||
local xs ={x}
|
||||
local ys = {y}
|
||||
local ls = {}
|
||||
|
||||
if options.active_cropping_rate > 0 then
|
||||
lowres_y = pairwise_utils.low_resolution(y)
|
||||
end
|
||||
if options.pairwise_flip then
|
||||
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
|
||||
end
|
||||
assert(#xs == #ys)
|
||||
for i = 1, n do
|
||||
local t = (i % #xs) + 1
|
||||
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
|
||||
|
@ -24,8 +34,8 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
|||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
yc = iproc.rgb2y(yc)
|
||||
xc = iproc.rgb2y(xc)
|
||||
end
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
|
|
|
@ -164,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
|
|||
|
||||
for j = 1, 2 do
|
||||
-- TTA
|
||||
local xi, yi, ri
|
||||
local xi, yi, ri, ni
|
||||
if j == 1 then
|
||||
xi = x
|
||||
ni = x_noise
|
||||
|
@ -176,7 +176,9 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
|
|||
ni = x_noise:transpose(2, 3):contiguous()
|
||||
end
|
||||
yi = y:transpose(2, 3):contiguous()
|
||||
ri = lowres_y:transpose(2, 3):contiguous()
|
||||
if lowres_y then
|
||||
ri = lowres_y:transpose(2, 3):contiguous()
|
||||
end
|
||||
end
|
||||
local xv = iproc.vflip(xi)
|
||||
local nv
|
||||
|
@ -184,34 +186,45 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
|
|||
nv = iproc.vflip(ni)
|
||||
end
|
||||
local yv = iproc.vflip(yi)
|
||||
local rv = iproc.vflip(ri)
|
||||
local rv
|
||||
if ri then
|
||||
rv = iproc.vflip(ri)
|
||||
end
|
||||
table.insert(xs, xi)
|
||||
if ni then
|
||||
table.insert(ns, ni)
|
||||
end
|
||||
table.insert(ys, yi)
|
||||
table.insert(ls, ri)
|
||||
if ri then
|
||||
table.insert(ls, ri)
|
||||
end
|
||||
|
||||
table.insert(xs, xv)
|
||||
if nv then
|
||||
table.insert(ns, nv)
|
||||
end
|
||||
table.insert(ys, yv)
|
||||
table.insert(ls, rv)
|
||||
if rv then
|
||||
table.insert(ls, rv)
|
||||
end
|
||||
|
||||
table.insert(xs, iproc.hflip(xi))
|
||||
if ni then
|
||||
table.insert(ns, iproc.hflip(ni))
|
||||
end
|
||||
table.insert(ys, iproc.hflip(yi))
|
||||
table.insert(ls, iproc.hflip(ri))
|
||||
if ri then
|
||||
table.insert(ls, iproc.hflip(ri))
|
||||
end
|
||||
|
||||
table.insert(xs, iproc.hflip(xv))
|
||||
if nv then
|
||||
table.insert(ns, iproc.hflip(nv))
|
||||
end
|
||||
table.insert(ys, iproc.hflip(yv))
|
||||
table.insert(ls, iproc.hflip(rv))
|
||||
if rv then
|
||||
table.insert(ls, iproc.hflip(rv))
|
||||
end
|
||||
end
|
||||
return xs, ys, ls, ns
|
||||
end
|
||||
|
|
|
@ -46,6 +46,7 @@ cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwi
|
|||
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
|
||||
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
|
||||
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
|
||||
cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
|
||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||
cmd:option("-crop_size", 48, 'crop size')
|
||||
|
@ -91,6 +92,7 @@ to_bool(settings, "plot")
|
|||
to_bool(settings, "save_history")
|
||||
to_bool(settings, "use_transparent_png")
|
||||
to_bool(settings, "pairwise_y_binary")
|
||||
to_bool(settings, "pairwise_flip")
|
||||
|
||||
if settings.plot then
|
||||
require 'gnuplot'
|
||||
|
|
|
@ -466,9 +466,6 @@ function srcnn.fcn_v1(backend, ch)
|
|||
model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1, true))
|
||||
model:add(nn.Dropout(0.5, false, true))
|
||||
model:add(SpatialConvolution(backend, 256, 256, 1, 1, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1, true))
|
||||
model:add(nn.Dropout(0.5, false, true))
|
||||
|
||||
model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1, true))
|
||||
|
|
25
train.lua
25
train.lua
|
@ -175,19 +175,36 @@ local function transform_pool_init(has_resize, offset)
|
|||
settings.crop_size, offset,
|
||||
n, conf)
|
||||
elseif settings.method == "user" then
|
||||
if is_validation == nil then is_validation = false end
|
||||
local rotate_rate = nil
|
||||
local scale_rate = nil
|
||||
local negate_rate = nil
|
||||
local negate_x_rate = nil
|
||||
if is_validation then
|
||||
rotate_rate = 0
|
||||
scale_rate = 0
|
||||
negate_rate = 0
|
||||
negate_x_rate = 0
|
||||
else
|
||||
rotate_rate = settings.random_pairwise_rotate_rate
|
||||
scale_rate = settings.random_pairwise_scale_rate
|
||||
negate_rate = settings.random_pairwise_negate_rate
|
||||
negate_x_rate = settings.random_pairwise_negate_x_rate
|
||||
end
|
||||
local conf = tablex.update({
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
|
||||
random_pairwise_rotate_rate = rotate_rate,
|
||||
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
|
||||
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
|
||||
random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
|
||||
random_pairwise_scale_rate = scale_rate,
|
||||
random_pairwise_scale_min = settings.random_pairwise_scale_min,
|
||||
random_pairwise_scale_max = settings.random_pairwise_scale_max,
|
||||
random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
|
||||
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
|
||||
random_pairwise_negate_rate = negate_rate,
|
||||
random_pairwise_negate_x_rate = negate_x_rate,
|
||||
pairwise_y_binary = settings.pairwise_y_binary,
|
||||
pairwise_flip = settings.pairwise_flip,
|
||||
rgb = (settings.color == "rgb")}, meta)
|
||||
return pairwise_transform.user(x, y,
|
||||
settings.crop_size, offset,
|
||||
|
|
Loading…
Reference in a new issue