perfomance tuning
This commit is contained in:
parent
6220afc31a
commit
bdaca16c67
|
@ -178,7 +178,7 @@ local function rotate_with_warp(src, dst, theta, mode)
|
||||||
flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
|
flow[2]:mul(-(width -1)):add(math.floor(width / 2 + 0.5))
|
||||||
flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
|
flow:add(-1, torch.mm(kernel, flow:view(2, height * width)))
|
||||||
dst:resizeAs(src)
|
dst:resizeAs(src)
|
||||||
return image.warp(dst, src, flow, mode, true, 'pad')
|
return image.warp(dst, src, flow, mode, true, 'clamp')
|
||||||
end
|
end
|
||||||
function iproc.rotate(src, theta)
|
function iproc.rotate(src, theta)
|
||||||
local conversion
|
local conversion
|
||||||
|
@ -212,6 +212,16 @@ function iproc.gaussian2d(kernel_size, sigma)
|
||||||
kernel:div(kernel:sum())
|
kernel:div(kernel:sum())
|
||||||
return kernel
|
return kernel
|
||||||
end
|
end
|
||||||
|
function iproc.rgb2y(src)
|
||||||
|
local conversion
|
||||||
|
src, conversion = iproc.byte2float(src)
|
||||||
|
local dest = torch.FloatTensor(1, src:size(2), src:size(3)):zero()
|
||||||
|
dest:add(0.299, src[1]):add(0.587, src[2]):add(0.114, src[3])
|
||||||
|
if conversion then
|
||||||
|
dest = iproc.float2byte(dest)
|
||||||
|
end
|
||||||
|
return dest
|
||||||
|
end
|
||||||
|
|
||||||
local function test_conversion()
|
local function test_conversion()
|
||||||
local a = torch.linspace(0, 255, 256):float():div(255.0)
|
local a = torch.linspace(0, 255, 256):float():div(255.0)
|
||||||
|
|
|
@ -13,8 +13,18 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
||||||
x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
|
x, y = pairwise_utils.preprocess_user(x, y, scale_y, size, options)
|
||||||
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
|
assert(x:size(3) == y:size(3) / scale_y and x:size(2) == y:size(2) / scale_y)
|
||||||
local batch = {}
|
local batch = {}
|
||||||
local lowres_y = pairwise_utils.low_resolution(y)
|
local lowres_y = nil
|
||||||
local xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
|
local xs ={x}
|
||||||
|
local ys = {y}
|
||||||
|
local ls = {}
|
||||||
|
|
||||||
|
if options.active_cropping_rate > 0 then
|
||||||
|
lowres_y = pairwise_utils.low_resolution(y)
|
||||||
|
end
|
||||||
|
if options.pairwise_flip then
|
||||||
|
xs, ys, ls = pairwise_utils.flip_augmentation(x, y, lowres_y)
|
||||||
|
end
|
||||||
|
assert(#xs == #ys)
|
||||||
for i = 1, n do
|
for i = 1, n do
|
||||||
local t = (i % #xs) + 1
|
local t = (i % #xs) + 1
|
||||||
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
|
local xc, yc = pairwise_utils.active_cropping(xs[t], ys[t], ls[t], size, scale_y,
|
||||||
|
@ -24,8 +34,8 @@ function pairwise_transform.user(x, y, size, offset, n, options)
|
||||||
yc = iproc.byte2float(yc)
|
yc = iproc.byte2float(yc)
|
||||||
if options.rgb then
|
if options.rgb then
|
||||||
else
|
else
|
||||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
yc = iproc.rgb2y(yc)
|
||||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
xc = iproc.rgb2y(xc)
|
||||||
end
|
end
|
||||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||||
end
|
end
|
||||||
|
|
|
@ -164,7 +164,7 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
|
||||||
|
|
||||||
for j = 1, 2 do
|
for j = 1, 2 do
|
||||||
-- TTA
|
-- TTA
|
||||||
local xi, yi, ri
|
local xi, yi, ri, ni
|
||||||
if j == 1 then
|
if j == 1 then
|
||||||
xi = x
|
xi = x
|
||||||
ni = x_noise
|
ni = x_noise
|
||||||
|
@ -176,7 +176,9 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
|
||||||
ni = x_noise:transpose(2, 3):contiguous()
|
ni = x_noise:transpose(2, 3):contiguous()
|
||||||
end
|
end
|
||||||
yi = y:transpose(2, 3):contiguous()
|
yi = y:transpose(2, 3):contiguous()
|
||||||
ri = lowres_y:transpose(2, 3):contiguous()
|
if lowres_y then
|
||||||
|
ri = lowres_y:transpose(2, 3):contiguous()
|
||||||
|
end
|
||||||
end
|
end
|
||||||
local xv = iproc.vflip(xi)
|
local xv = iproc.vflip(xi)
|
||||||
local nv
|
local nv
|
||||||
|
@ -184,34 +186,45 @@ function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
|
||||||
nv = iproc.vflip(ni)
|
nv = iproc.vflip(ni)
|
||||||
end
|
end
|
||||||
local yv = iproc.vflip(yi)
|
local yv = iproc.vflip(yi)
|
||||||
local rv = iproc.vflip(ri)
|
local rv
|
||||||
|
if ri then
|
||||||
|
rv = iproc.vflip(ri)
|
||||||
|
end
|
||||||
table.insert(xs, xi)
|
table.insert(xs, xi)
|
||||||
if ni then
|
if ni then
|
||||||
table.insert(ns, ni)
|
table.insert(ns, ni)
|
||||||
end
|
end
|
||||||
table.insert(ys, yi)
|
table.insert(ys, yi)
|
||||||
table.insert(ls, ri)
|
if ri then
|
||||||
|
table.insert(ls, ri)
|
||||||
|
end
|
||||||
|
|
||||||
table.insert(xs, xv)
|
table.insert(xs, xv)
|
||||||
if nv then
|
if nv then
|
||||||
table.insert(ns, nv)
|
table.insert(ns, nv)
|
||||||
end
|
end
|
||||||
table.insert(ys, yv)
|
table.insert(ys, yv)
|
||||||
table.insert(ls, rv)
|
if rv then
|
||||||
|
table.insert(ls, rv)
|
||||||
|
end
|
||||||
|
|
||||||
table.insert(xs, iproc.hflip(xi))
|
table.insert(xs, iproc.hflip(xi))
|
||||||
if ni then
|
if ni then
|
||||||
table.insert(ns, iproc.hflip(ni))
|
table.insert(ns, iproc.hflip(ni))
|
||||||
end
|
end
|
||||||
table.insert(ys, iproc.hflip(yi))
|
table.insert(ys, iproc.hflip(yi))
|
||||||
table.insert(ls, iproc.hflip(ri))
|
if ri then
|
||||||
|
table.insert(ls, iproc.hflip(ri))
|
||||||
|
end
|
||||||
|
|
||||||
table.insert(xs, iproc.hflip(xv))
|
table.insert(xs, iproc.hflip(xv))
|
||||||
if nv then
|
if nv then
|
||||||
table.insert(ns, iproc.hflip(nv))
|
table.insert(ns, iproc.hflip(nv))
|
||||||
end
|
end
|
||||||
table.insert(ys, iproc.hflip(yv))
|
table.insert(ys, iproc.hflip(yv))
|
||||||
table.insert(ls, iproc.hflip(rv))
|
if rv then
|
||||||
|
table.insert(ls, iproc.hflip(rv))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
return xs, ys, ls, ns
|
return xs, ys, ls, ns
|
||||||
end
|
end
|
||||||
|
|
|
@ -46,6 +46,7 @@ cmd:option("-random_pairwise_rotate_max", 6, 'max rotate angle for random pairwi
|
||||||
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
|
cmd:option("-random_pairwise_negate_rate", 0.0, 'data augmentation using nagate image for user method')
|
||||||
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
|
cmd:option("-random_pairwise_negate_x_rate", 0.0, 'data augmentation using nagate image only x side for user method')
|
||||||
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
|
cmd:option("-pairwise_y_binary", 0, 'binarize y after data augmentation(0|1)')
|
||||||
|
cmd:option("-pairwise_flip", 1, 'use flip(0|1)')
|
||||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||||
cmd:option("-crop_size", 48, 'crop size')
|
cmd:option("-crop_size", 48, 'crop size')
|
||||||
|
@ -91,6 +92,7 @@ to_bool(settings, "plot")
|
||||||
to_bool(settings, "save_history")
|
to_bool(settings, "save_history")
|
||||||
to_bool(settings, "use_transparent_png")
|
to_bool(settings, "use_transparent_png")
|
||||||
to_bool(settings, "pairwise_y_binary")
|
to_bool(settings, "pairwise_y_binary")
|
||||||
|
to_bool(settings, "pairwise_flip")
|
||||||
|
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
require 'gnuplot'
|
require 'gnuplot'
|
||||||
|
|
|
@ -466,9 +466,6 @@ function srcnn.fcn_v1(backend, ch)
|
||||||
model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
|
model:add(SpatialConvolution(backend, 128, 256, 1, 1, 1, 1, 0, 0))
|
||||||
model:add(nn.LeakyReLU(0.1, true))
|
model:add(nn.LeakyReLU(0.1, true))
|
||||||
model:add(nn.Dropout(0.5, false, true))
|
model:add(nn.Dropout(0.5, false, true))
|
||||||
model:add(SpatialConvolution(backend, 256, 256, 1, 1, 1, 1, 0, 0))
|
|
||||||
model:add(nn.LeakyReLU(0.1, true))
|
|
||||||
model:add(nn.Dropout(0.5, false, true))
|
|
||||||
|
|
||||||
model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
|
model:add(SpatialFullConvolution(backend, 256, 128, 2, 2, 2, 2, 0, 0))
|
||||||
model:add(nn.LeakyReLU(0.1, true))
|
model:add(nn.LeakyReLU(0.1, true))
|
||||||
|
|
25
train.lua
25
train.lua
|
@ -175,19 +175,36 @@ local function transform_pool_init(has_resize, offset)
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
n, conf)
|
n, conf)
|
||||||
elseif settings.method == "user" then
|
elseif settings.method == "user" then
|
||||||
|
if is_validation == nil then is_validation = false end
|
||||||
|
local rotate_rate = nil
|
||||||
|
local scale_rate = nil
|
||||||
|
local negate_rate = nil
|
||||||
|
local negate_x_rate = nil
|
||||||
|
if is_validation then
|
||||||
|
rotate_rate = 0
|
||||||
|
scale_rate = 0
|
||||||
|
negate_rate = 0
|
||||||
|
negate_x_rate = 0
|
||||||
|
else
|
||||||
|
rotate_rate = settings.random_pairwise_rotate_rate
|
||||||
|
scale_rate = settings.random_pairwise_scale_rate
|
||||||
|
negate_rate = settings.random_pairwise_negate_rate
|
||||||
|
negate_x_rate = settings.random_pairwise_negate_x_rate
|
||||||
|
end
|
||||||
local conf = tablex.update({
|
local conf = tablex.update({
|
||||||
max_size = settings.max_size,
|
max_size = settings.max_size,
|
||||||
active_cropping_rate = active_cropping_rate,
|
active_cropping_rate = active_cropping_rate,
|
||||||
active_cropping_tries = active_cropping_tries,
|
active_cropping_tries = active_cropping_tries,
|
||||||
random_pairwise_rotate_rate = settings.random_pairwise_rotate_rate,
|
random_pairwise_rotate_rate = rotate_rate,
|
||||||
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
|
random_pairwise_rotate_min = settings.random_pairwise_rotate_min,
|
||||||
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
|
random_pairwise_rotate_max = settings.random_pairwise_rotate_max,
|
||||||
random_pairwise_scale_rate = settings.random_pairwise_scale_rate,
|
random_pairwise_scale_rate = scale_rate,
|
||||||
random_pairwise_scale_min = settings.random_pairwise_scale_min,
|
random_pairwise_scale_min = settings.random_pairwise_scale_min,
|
||||||
random_pairwise_scale_max = settings.random_pairwise_scale_max,
|
random_pairwise_scale_max = settings.random_pairwise_scale_max,
|
||||||
random_pairwise_negate_rate = settings.random_pairwise_negate_rate,
|
random_pairwise_negate_rate = negate_rate,
|
||||||
random_pairwise_negate_x_rate = settings.random_pairwise_negate_x_rate,
|
random_pairwise_negate_x_rate = negate_x_rate,
|
||||||
pairwise_y_binary = settings.pairwise_y_binary,
|
pairwise_y_binary = settings.pairwise_y_binary,
|
||||||
|
pairwise_flip = settings.pairwise_flip,
|
||||||
rgb = (settings.color == "rgb")}, meta)
|
rgb = (settings.color == "rgb")}, meta)
|
||||||
return pairwise_transform.user(x, y,
|
return pairwise_transform.user(x, y,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
|
|
Loading…
Reference in a new issue