1
0
Fork 0
mirror of synced 2024-05-16 19:02:21 +12:00

perfomance tuning

This commit is contained in:
nagadomi 2016-11-02 23:41:56 +09:00
parent 6220afc31a
commit bdaca16c67
6 changed files with 68 additions and 19 deletions

View file

@ -178,7 +178,7 @@ local function rotate_with_warp(src, dst, theta, mode)
flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
dst:resizeAs(src)
return image.warp(dst, src, flow, mode, true, 'pad')
return image.warp(dst, src, flow, mode, true, 'clamp')
end
function iproc.rotate(src, theta)
local conversion
@ -212,6 +212,16 @@ function iproc.gaussian2d(kernel_size, sigma)
kernel:div(kernel:sum())
return kernel
end
function iproc.rgb2y(src)
local conversion
src, conversion = iproc.byte2float(src)
local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
local function test_conversion()
local a = torch.linspace(0, 255, 256):float():div(255.0)

View file

@ -13,8 +13,18 @@ function pairwise_transform.user(x, y, size, offset, n, options)
x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
local batch = {}
local lowres_y = pairwise_utils.low_resolution(y)
local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
local lowres_y = nil
local xs ={x}
local ys = {y}
local ls = {}
if options.active_cropping_rate > 0 then
lowres_y = pairwise_utils.low_resolution(y)
end
if options.pairwise_flip then
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
end
assert(#xs == #ys)
for i = 1, n do
local t = (i % #xs) + 1
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
@ -24,8 +34,8 @@ function pairwise_transform.user(x, y, size, offset, n, options)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
yc = iproc.rgb2y(yc)
xc = iproc.rgb2y(xc)
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end

View file

@ -164,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
for j = 1, 2 do
-- TTA
local xi, yi, ri
local xi, yi, ri, ni
if j == 1 then
xi = x
ni = x_noise
@ -176,7 +176,9 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
ni = x_noise:transpose(2, 3):contiguous()
end
yi = y:transpose(2, 3):contiguous()
ri = lowres_y:transpose(2, 3):contiguous()
if lowres_y then
ri = lowres_y:transpose(2, 3):contiguous()
end
end
local xv = iproc.vflip(xi)
local nv
@ -184,34 +186,45 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
nv = iproc.vflip(ni)
end
local yv = iproc.vflip(yi)
local rv = iproc.vflip(ri)
local rv
if ri then
rv = iproc.vflip(ri)
end
table.insert(xs, xi)
if ni then
table.insert(ns, ni)
end
table.insert(ys, yi)
table.insert(ls, ri)
if ri then
table.insert(ls, ri)
end
table.insert(xs, xv)
if nv then
table.insert(ns, nv)
end
table.insert(ys, yv)
table.insert(ls, rv)
if rv then
table.insert(ls, rv)
end
table.insert(xs, iproc.hflip(xi))
if ni then
table.insert(ns, iproc.hflip(ni))
end
table.insert(ys, iproc.hflip(yi))
table.insert(ls, iproc.hflip(ri))
if ri then
table.insert(ls, iproc.hflip(ri))
end
table.insert(xs, iproc.hflip(xv))
if nv then
table.insert(ns, iproc.hflip(nv))
end
table.insert(ys, iproc.hflip(yv))
table.insert(ls, iproc.hflip(rv))
if rv then
table.insert(ls, iproc.hflip(rv))
end
end
return xs, ys, ls, ns
end

View file

@ -46,6 +46,7 @@ cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwi
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-crop_size", 48, 'crop size')
@ -91,6 +92,7 @@ to_bool(settings, "plot")
to_bool(settings, "save_history")
to_bool(settings, "use_transparent_png")
to_bool(settings, "pairwise_y_binary")
to_bool(settings, "pairwise_flip")
if settings.plot then
require 'gnuplot'

View file

@ -466,9 +466,6 @@ function srcnn.fcn_v1(backend, ch)
model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(nn.Dropout(0.5, false, true))
model:add(SpatialConvolution(backend, 256, 256, 1, 1, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(nn.Dropout(0.5, false, true))
model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
model:add(nn.LeakyReLU(0.1, true))

View file

@ -175,19 +175,36 @@ local function transform_pool_init(has_resize, offset)
settings.crop_size, offset,
n, conf)
elseif settings.method == "user" then
if is_validation == nil then is_validation = false end
local rotate_rate = nil
local scale_rate = nil
local negate_rate = nil
local negate_x_rate = nil
if is_validation then
rotate_rate = 0
scale_rate = 0
negate_rate = 0
negate_x_rate = 0
else
rotate_rate = settings.random_pairwise_rotate_rate
scale_rate = settings.random_pairwise_scale_rate
negate_rate = settings.random_pairwise_negate_rate
negate_x_rate = settings.random_pairwise_negate_x_rate
end
local conf = tablex.update({
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
random_pairwise_rotate_rate = rotate_rate,
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
random_pairwise_scale_rate = scale_rate,
random_pairwise_scale_min = settings.random_pairwise_scale_min,
random_pairwise_scale_max = settings.random_pairwise_scale_max,
random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
random_pairwise_negate_rate = negate_rate,
random_pairwise_negate_x_rate = negate_x_rate,
pairwise_y_binary = settings.pairwise_y_binary,
pairwise_flip = settings.pairwise_flip,
rgb = (settings.color == "rgb")}, meta)
return pairwise_transform.user(x, y,
settings.crop_size, offset,