1
0
Fork 0
mirror of synced 2024-06-13 08:24:30 +12:00

perfomance tuning

This commit is contained in:
nagadomi 2016-11-02 23:41:56 +09:00
parent 6220afc31a
commit bdaca16c67
6 changed files with 68 additions and 19 deletions

View file

@ -178,7 +178,7 @@ local function rotate_with_warp(src, dst, theta, mode)
flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5)) flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
flow:add(-1, torch.mm(kernel, flow:view(2, height * width))) flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
dst:resizeAs(src) dst:resizeAs(src)
return image.warp(dst, src, flow, mode, true, 'pad') return image.warp(dst, src, flow, mode, true, 'clamp')
end end
function iproc.rotate(src, theta) function iproc.rotate(src, theta)
local conversion local conversion
@ -212,6 +212,16 @@ function iproc.gaussian2d(kernel_size, sigma)
kernel:div(kernel:sum()) kernel:div(kernel:sum())
return kernel return kernel
end end
function iproc.rgb2y(src)
local conversion
src, conversion = iproc.byte2float(src)
local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
local function test_conversion() local function test_conversion()
local a = torch.linspace(0, 255, 256):float():div(255.0) local a = torch.linspace(0, 255, 256):float():div(255.0)

View file

@ -13,8 +13,18 @@ function pairwise_transform.user(x, y, size, offset, n, options)
x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options) x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y) assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
local batch = {} local batch = {}
local lowres_y = pairwise_utils.low_resolution(y) local lowres_y = nil
local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y) local xs ={x}
local ys = {y}
local ls = {}
if options.active_cropping_rate > 0 then
lowres_y = pairwise_utils.low_resolution(y)
end
if options.pairwise_flip then
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
end
assert(#xs == #ys)
for i = 1, n do for i = 1, n do
local t = (i % #xs) + 1 local t = (i % #xs) + 1
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y, local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
@ -24,8 +34,8 @@ function pairwise_transform.user(x, y, size, offset, n, options)
yc = iproc.byte2float(yc) yc = iproc.byte2float(yc)
if options.rgb then if options.rgb then
else else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) yc = iproc.rgb2y(yc)
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) xc = iproc.rgb2y(xc)
end end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end end

View file

@ -164,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
for j = 1, 2 do for j = 1, 2 do
-- TTA -- TTA
local xi, yi, ri local xi, yi, ri, ni
if j == 1 then if j == 1 then
xi = x xi = x
ni = x_noise ni = x_noise
@ -176,7 +176,9 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
ni = x_noise:transpose(2, 3):contiguous() ni = x_noise:transpose(2, 3):contiguous()
end end
yi = y:transpose(2, 3):contiguous() yi = y:transpose(2, 3):contiguous()
ri = lowres_y:transpose(2, 3):contiguous() if lowres_y then
ri = lowres_y:transpose(2, 3):contiguous()
end
end end
local xv = iproc.vflip(xi) local xv = iproc.vflip(xi)
local nv local nv
@ -184,34 +186,45 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
nv = iproc.vflip(ni) nv = iproc.vflip(ni)
end end
local yv = iproc.vflip(yi) local yv = iproc.vflip(yi)
local rv = iproc.vflip(ri) local rv
if ri then
rv = iproc.vflip(ri)
end
table.insert(xs, xi) table.insert(xs, xi)
if ni then if ni then
table.insert(ns, ni) table.insert(ns, ni)
end end
table.insert(ys, yi) table.insert(ys, yi)
table.insert(ls, ri) if ri then
table.insert(ls, ri)
end
table.insert(xs, xv) table.insert(xs, xv)
if nv then if nv then
table.insert(ns, nv) table.insert(ns, nv)
end end
table.insert(ys, yv) table.insert(ys, yv)
table.insert(ls, rv) if rv then
table.insert(ls, rv)
end
table.insert(xs, iproc.hflip(xi)) table.insert(xs, iproc.hflip(xi))
if ni then if ni then
table.insert(ns, iproc.hflip(ni)) table.insert(ns, iproc.hflip(ni))
end end
table.insert(ys, iproc.hflip(yi)) table.insert(ys, iproc.hflip(yi))
table.insert(ls, iproc.hflip(ri)) if ri then
table.insert(ls, iproc.hflip(ri))
end
table.insert(xs, iproc.hflip(xv)) table.insert(xs, iproc.hflip(xv))
if nv then if nv then
table.insert(ns, iproc.hflip(nv)) table.insert(ns, iproc.hflip(nv))
end end
table.insert(ys, iproc.hflip(yv)) table.insert(ys, iproc.hflip(yv))
table.insert(ls, iproc.hflip(rv)) if rv then
table.insert(ls, iproc.hflip(rv))
end
end end
return xs, ys, ls, ns return xs, ys, ls, ns
end end

View file

@ -46,6 +46,7 @@ cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwi
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method') cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method') cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)') cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
cmd:option("-scale", 2.0, 'scale factor (2)') cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam') cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-crop_size", 48, 'crop size') cmd:option("-crop_size", 48, 'crop size')
@ -91,6 +92,7 @@ to_bool(settings, "plot")
to_bool(settings, "save_history") to_bool(settings, "save_history")
to_bool(settings, "use_transparent_png") to_bool(settings, "use_transparent_png")
to_bool(settings, "pairwise_y_binary") to_bool(settings, "pairwise_y_binary")
to_bool(settings, "pairwise_flip")
if settings.plot then if settings.plot then
require 'gnuplot' require 'gnuplot'

View file

@ -466,9 +466,6 @@ function srcnn.fcn_v1(backend, ch)
model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0)) model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true)) model:add(nn.LeakyReLU(0.1, true))
model:add(nn.Dropout(0.5, false, true)) model:add(nn.Dropout(0.5, false, true))
model:add(SpatialConvolution(backend, 256, 256, 1, 1, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(nn.Dropout(0.5, false, true))
model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0)) model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
model:add(nn.LeakyReLU(0.1, true)) model:add(nn.LeakyReLU(0.1, true))

View file

@ -175,19 +175,36 @@ local function transform_pool_init(has_resize, offset)
settings.crop_size, offset, settings.crop_size, offset,
n, conf) n, conf)
elseif settings.method == "user" then elseif settings.method == "user" then
if is_validation == nil then is_validation = false end
local rotate_rate = nil
local scale_rate = nil
local negate_rate = nil
local negate_x_rate = nil
if is_validation then
rotate_rate = 0
scale_rate = 0
negate_rate = 0
negate_x_rate = 0
else
rotate_rate = settings.random_pairwise_rotate_rate
scale_rate = settings.random_pairwise_scale_rate
negate_rate = settings.random_pairwise_negate_rate
negate_x_rate = settings.random_pairwise_negate_x_rate
end
local conf = tablex.update({ local conf = tablex.update({
max_size = settings.max_size, max_size = settings.max_size,
active_cropping_rate = active_cropping_rate, active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries, active_cropping_tries = active_cropping_tries,
random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate, random_pairwise_rotate_rate = rotate_rate,
random_pairwise_rotate_min = settings.random_pairwise_rotate_min, random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
random_pairwise_rotate_max = settings.random_pairwise_rotate_max, random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
random_pairwise_scale_rate = settings.random_pairwise_scale_rate, random_pairwise_scale_rate = scale_rate,
random_pairwise_scale_min = settings.random_pairwise_scale_min, random_pairwise_scale_min = settings.random_pairwise_scale_min,
random_pairwise_scale_max = settings.random_pairwise_scale_max, random_pairwise_scale_max = settings.random_pairwise_scale_max,
random_pairwise_negate_rate = settings.random_pairwise_negate_rate, random_pairwise_negate_rate = negate_rate,
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate, random_pairwise_negate_x_rate = negate_x_rate,
pairwise_y_binary = settings.pairwise_y_binary, pairwise_y_binary = settings.pairwise_y_binary,
pairwise_flip = settings.pairwise_flip,
rgb = (settings.color == "rgb")}, meta) rgb = (settings.color == "rgb")}, meta)
return pairwise_transform.user(x, y, return pairwise_transform.user(x, y,
settings.crop_size, offset, settings.crop_size, offset,