2016-09-09 01:35:27 +12:00
|
|
|
require 'cunn'
|
2016-05-13 12:49:53 +12:00
|
|
|
local iproc = require 'iproc'
|
2016-09-11 08:07:42 +12:00
|
|
|
local gm = {}
|
|
|
|
gm.Image = require 'graphicsmagick.Image'
|
2016-05-13 12:49:53 +12:00
|
|
|
local data_augmentation = require 'data_augmentation'
|
|
|
|
local pairwise_transform_utils = {}
|
|
|
|
|
|
|
|
--- Randomly shrink an image to half of its size.
-- Draws torch.uniform(); when the draw is below `p`, one entry of `filters`
-- is picked uniformly at random and handed to iproc.scale together with half
-- of src:size(3) and half of src:size(2). Otherwise `src` is returned as is.
-- @param src image tensor; only dims 2 and 3 are read here
-- @param p threshold compared against the uniform draw
-- @param filters non-empty array of filter identifiers understood by iproc.scale
-- @return result of iproc.scale, or `src` itself when no scaling happened
function pairwise_transform_utils.random_half(src, p, filters)
   -- Guard clause: leave the image alone unless the draw falls below p.
   if not (torch.uniform() < p) then
      return src
   end
   local chosen = filters[torch.random(1, #filters)]
   local half_dim3 = src:size(3) * 0.5
   local half_dim2 = src:size(2) * 0.5
   return iproc.scale(src, half_dim3, half_dim2, chosen)
end
|
2016-06-06 17:04:13 +12:00
|
|
|
--- Crop a random max_size x max_size window out of a large image.
-- Only acts when both src:size(2) and src:size(3) exceed `max_size`;
-- otherwise `src` is returned unchanged.
-- Up to 4 random windows are tried. A window whose pixel standard deviation
-- is exactly 0 (a perfectly flat region, e.g. plain background) is rejected
-- and another one is drawn; if every attempt is flat, the last one is used.
-- @param src image tensor; dims 2 and 3 are treated as the croppable extents
-- @param max_size side length of the crop; must be a multiple of 4
-- @param mod optional; when given, the window origin is rounded down to a
--        multiple of `mod` on both axes
-- @return the cropped window (from iproc.crop), or `src` when it is not large
function pairwise_transform_utils.crop_if_large(src, max_size, mod)
   local tries = 4
   if not (src:size(2) > max_size and src:size(3) > max_size) then
      return src
   end
   assert(max_size % 4 == 0, "max_size must be a multiple of 4")
   local rect
   for _ = 1, tries do
      local yi = torch.random(0, src:size(2) - max_size)
      local xi = torch.random(0, src:size(3) - max_size)
      if mod then
         -- snap the origin down to the requested alignment
         yi = yi - (yi % mod)
         xi = xi - (xi % mod)
      end
      rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
      -- ignore simple background: the comparison must be strict, because a
      -- standard deviation is never negative and `>= 0` would accept every
      -- window on the first attempt, making the retry loop pointless.
      if rect:float():std() > 0 then
         break
      end
   end
   return rect
end
|
2016-10-21 19:43:28 +13:00
|
|
|
--- Crop matching random windows out of an (x, y) image pair.
-- Only acts when both y:size(2) and y:size(3) exceed `max_size`; otherwise
-- the pair is returned unchanged.
-- The window is chosen in y coordinates; the x window uses the same origin
-- and extent divided by `scale_y`.
-- Up to 4 random windows are tried. A y window whose standard deviation is
-- exactly 0 (perfectly flat) is rejected and another one is drawn; if every
-- attempt is flat, the last one is used.
-- @param x first image of the pair
-- @param y second image of the pair; its dims 2 and 3 drive the decision
-- @param scale_y divisor that maps y coordinates to x coordinates
-- @param max_size side length of the y crop; must be a multiple of 4
-- @param mod optional; when given, the y origin is rounded down to a multiple
--        of `mod`. NOTE(review): without `mod` the x coordinates may be
--        fractional when scale_y ~= 1 -- confirm iproc.crop tolerates that.
-- @return rect_x, rect_y when cropped, otherwise x, y
function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
   local tries = 4
   if not (y:size(2) > max_size and y:size(3) > max_size) then
      return x, y
   end
   assert(max_size % 4 == 0, "max_size must be a multiple of 4")
   local rect_y
   local xi, yi
   for _ = 1, tries do
      yi = torch.random(0, y:size(2) - max_size)
      xi = torch.random(0, y:size(3) - max_size)
      if mod then
         -- snap the origin down to the requested alignment
         yi = yi - (yi % mod)
         xi = xi - (xi % mod)
      end
      rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
      -- ignore simple background: strict comparison, since a standard
      -- deviation is never negative and `>= 0` would accept every window on
      -- the first attempt.
      if rect_y:float():std() > 0 then
         break
      end
   end
   -- Crop x once, for the window that was finally selected.
   local rect_x = iproc.crop(x,
                             xi / scale_y,
                             yi / scale_y,
                             xi / scale_y + max_size / scale_y,
                             yi / scale_y + max_size / scale_y)
   return rect_x, rect_y
end
|
2016-05-13 12:49:53 +12:00
|
|
|
--- Apply the random augmentation pipeline to a single source image.
-- Two pipelines exist, selected by options.data.filters:
--   * "Box only" (the list is exactly {"Box"}): crop with origin aligned to 2,
--     colour noise, overlay, unsharp mask, then iproc.crop_mod4.
--   * otherwise: random halving, crop, blur, colour noise, overlay, unsharp
--     mask, then data_augmentation.shift_1px.
-- In both cases the crop limit is math.max(crop_size * 2, options.max_size).
-- @param src source image tensor
-- @param crop_size crop size used to derive the crop limit
-- @param options settings table; the fields read are visible in the body
-- @return the augmented image
function pairwise_transform_utils.preprocess(src, crop_size, options)
   local filters = options.data.filters
   local box_only = false
   if filters and #filters == 1 and filters[1] == "Box" then
      box_only = true
   end

   if box_only then
      local mod = 2 -- assert pos % 2 == 0
      local img = pairwise_transform_utils.crop_if_large(
         src, math.max(crop_size * 2, options.max_size), mod)
      img = data_augmentation.color_noise(img, options.random_color_noise_rate)
      img = data_augmentation.overlay(img, options.random_overlay_rate)
      img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
      return iproc.crop_mod4(img)
   end

   local img = pairwise_transform_utils.random_half(
      src, options.random_half_rate, options.downsampling_filters)
   img = pairwise_transform_utils.crop_if_large(
      img, math.max(crop_size * 2, options.max_size))
   img = data_augmentation.blur(
      img,
      options.random_blur_rate,
      options.random_blur_size,
      options.random_blur_sigma_min,
      options.random_blur_sigma_max)
   img = data_augmentation.color_noise(img, options.random_color_noise_rate)
   img = data_augmentation.overlay(img, options.random_overlay_rate)
   img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
   return data_augmentation.shift_1px(img)
end
|
2016-10-21 19:43:28 +13:00
|
|
|
--- Apply the pairwise augmentation pipeline to a user-supplied (x, y) pair.
-- Steps, in order: paired crop (window origin aligned to scale_y), pairwise
-- rotate, pairwise scale, pairwise negate, pairwise negate_x, crop_mod4 on
-- both images, and optional binarisation of y.
-- @param x first image of the pair
-- @param y second image of the pair
-- @param scale_y scale factor relating y to x; also used as crop alignment
-- @param size used for the lower bound of the random scale: size / (1 + s),
--        where s is the smaller of x:size(2) and x:size(3)
-- @param options settings table; the fields read are visible in the body
-- @return x, y after augmentation
function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
   local aug = data_augmentation

   x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, options.max_size, scale_y)
   x, y = aug.pairwise_rotate(x, y,
                              options.random_pairwise_rotate_rate,
                              options.random_pairwise_rotate_min,
                              options.random_pairwise_rotate_max)

   -- Raise the lower scale bound to at least size / (1 + shorter side of x),
   -- and never let the upper bound fall below the lower one.
   local lower = math.max(options.random_pairwise_scale_min,
                          size / (1 + math.min(x:size(2), x:size(3))))
   local upper = math.max(lower, options.random_pairwise_scale_max)
   x, y = aug.pairwise_scale(x, y, options.random_pairwise_scale_rate, lower, upper)
   x, y = aug.pairwise_negate(x, y, options.random_pairwise_negate_rate)
   x, y = aug.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)

   x = iproc.crop_mod4(x)
   y = iproc.crop_mod4(y)

   if options.pairwise_y_binary then
      -- Threshold at 128: everything below becomes 0, every remaining
      -- non-zero value becomes 255.
      y[torch.lt(y, 128)] = 0
      y[torch.gt(y, 0)] = 255
   end

   return x, y
end
|
2016-05-30 22:15:54 +12:00
|
|
|
--- Crop a matching (x, y) patch pair, optionally preferring "hard" regions.
-- A uniform draw r decides the strategy:
--   * p < r: a single random window is used.
--   * otherwise: `tries` random windows are drawn, and the one with the
--     largest sum of squared differences between y and lowres_y is used.
-- Windows are expressed in y coordinates (size x size); the x crop uses the
-- same origin and extent divided by `scale`.
-- @param x input image; its dims 2 and 3 times `scale` must equal those of y
-- @param y target image
-- @param lowres_y image cropped with the same window as y in the active
--        branch (presumably the same size as y -- confirm against caller)
-- @param size side length of the y crop; must be divisible by `scale`
-- @param scale integer ratio between y and x
-- @param p threshold compared against the uniform draw
-- @param tries number of candidate windows in the active branch
-- @return xc, yc the cropped x and y patches
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
   -- assert takes (condition, message); with the arguments the other way
   -- round the string is always truthy and the precondition is never checked.
   assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
          "x:size == y:size")
   assert(size % scale == 0, "crop_size % scale == 0")
   local r = torch.uniform()
   if p < r then
      -- Plain random crop. Offsets stay 0 when x is too small to move.
      local xi = 0
      local yi = 0
      if x:size(3) > size + 1 then
         xi = torch.random(0, x:size(3) - (size + 1)) * scale
      end
      if x:size(2) > size + 1 then
         yi = torch.random(0, x:size(2) - (size + 1)) * scale
      end
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
      return xc, yc
   else
      -- Active crop: collect `tries` candidate windows of y and lowres_y.
      local ycs = torch.LongTensor(tries, y:size(1), size, size)
      local lcs = torch.LongTensor(tries, lowres_y:size(1), size, size)
      local rects = {}
      -- NOTE(review): unlike the random branch, the offsets here start at 1
      -- and there is no guard for x being no larger than size + 1.
      local offsets = torch.LongTensor(2, tries)
      offsets[1]:random(1, x:size(3) - (size + 1)):mul(scale)
      offsets[2]:random(1, x:size(2) - (size + 1)):mul(scale)
      for i = 1, tries do
         local xi = offsets[1][i]
         local yi = offsets[2][i]
         local yc = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
         local lc = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
         ycs[i]:copy(yc)
         lcs[i]:copy(lc)
         rects[i] = {xi, yi}
      end
      -- Squared difference per element, summed per candidate.
      ycs:csub(lcs)
      ycs:cmul(ycs)
      local _, l = ycs:reshape(ycs:size(1), ycs:nElement() / ycs:size(1)):transpose(1, 2):sum(1):topk(1, true)
      -- l[1][1] is the index of the candidate selected by topk(1, true).
      local best = rects[l[1][1]]
      local best_xi = best[1]
      local best_yi = best[2]
      local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
      local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale)
      return xc, yc
   end
end
|
2016-06-18 23:25:15 +12:00
|
|
|
--- Build the flipped/transposed variants of a set of aligned images.
-- Two passes are made: the first over the images as given, the second over
-- their dim-2/dim-3 transposes (made contiguous). Each pass appends four
-- variants in this order: unchanged, vflip, hflip, hflip of the vflip.
-- `lowres_y` and `x_noise` are optional; when one is absent, its list is
-- returned empty.
-- @param x image tensor
-- @param y image tensor
-- @param lowres_y optional image tensor
-- @param x_noise optional image tensor
-- @return xs, ys, ls, ns lists of variants of x, y, lowres_y and x_noise
function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
   local xs, ns, ys, ls = {}, {}, {}, {}

   -- Append an optional image only when it is present.
   local function push_opt(list, img)
      if img then
         table.insert(list, img)
      end
   end
   -- Append the horizontal flip of an optional image only when it is present.
   local function push_hflip_opt(list, img)
      if img then
         table.insert(list, iproc.hflip(img))
      end
   end

   for pass = 1, 2 do
      -- TTA
      local xi, yi, ri, ni
      if pass == 1 then
         xi, ni, yi, ri = x, x_noise, y, lowres_y
      else
         xi = x:transpose(2, 3):contiguous()
         if x_noise then
            ni = x_noise:transpose(2, 3):contiguous()
         end
         yi = y:transpose(2, 3):contiguous()
         if lowres_y then
            ri = lowres_y:transpose(2, 3):contiguous()
         end
      end

      -- Vertical flips of everything available in this pass.
      local xv = iproc.vflip(xi)
      local nv
      if x_noise then
         nv = iproc.vflip(ni)
      end
      local yv = iproc.vflip(yi)
      local rv
      if ri then
         rv = iproc.vflip(ri)
      end

      -- 1) unchanged
      table.insert(xs, xi)
      push_opt(ns, ni)
      table.insert(ys, yi)
      push_opt(ls, ri)

      -- 2) vertical flip
      table.insert(xs, xv)
      push_opt(ns, nv)
      table.insert(ys, yv)
      push_opt(ls, rv)

      -- 3) horizontal flip
      table.insert(xs, iproc.hflip(xi))
      push_hflip_opt(ns, ni)
      table.insert(ys, iproc.hflip(yi))
      push_hflip_opt(ls, ri)

      -- 4) horizontal flip of the vertical flip
      table.insert(xs, iproc.hflip(xv))
      push_hflip_opt(ns, nv)
      table.insert(ys, iproc.hflip(yv))
      push_hflip_opt(ls, rv)
   end

   return xs, ys, ls, ns
end
|
2016-09-09 01:35:27 +12:00
|
|
|
--- Build a CUDA network made of a 2x2 average pooling with stride 2 followed
-- by a nearest-neighbour upsampling by a factor of 2.
-- @return the nn.Sequential container after :cuda()
local function lowres_model()
   return nn.Sequential()
      :add(nn.SpatialAveragePooling(2, 2, 2, 2))
      :add(nn.SpatialUpSamplingNearest(2))
      :cuda()
end
|
|
|
|
-- Module-level slots, both initially nil: a cached model and a GPU/CPU flag.
local g_lowres_model, g_lowres_gpu = nil, nil
|
|
|
|
--- Produce a low-resolution rendition of an image at its original size.
-- The image is resized to half of its width and height with the "Box"
-- filter, resized back to the original size with the same filter, and
-- returned as a byte RGB tensor in "DHW" layout.
-- @param src RGB image tensor in "DHW" layout
-- @return byte tensor of the same width and height as `src`
function pairwise_transform_utils.low_resolution(src)
   --[[
   -- I am not sure that the following process is thread-safe

   g_lowres_model = g_lowres_model or lowres_model()
   if g_lowres_gpu == nil then
      --benchmark
      local gpu_time = sys.clock()
      for i = 1, 10 do
         g_lowres_model:forward(src:cuda()):byte()
      end
      gpu_time = sys.clock() - gpu_time

      local cpu_time = sys.clock()
      for i = 1, 10 do
         gm.Image(src, "RGB", "DHW"):
            size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
            size(src:size(3), src:size(2), "Box"):
            toTensor("byte", "RGB", "DHW")
      end
      cpu_time = sys.clock() - cpu_time
      --print(gpu_time, cpu_time)
      if gpu_time < cpu_time then
         g_lowres_gpu = true
      else
         g_lowres_gpu = false
      end
   end
   if g_lowres_gpu then
      return g_lowres_model:forward(src:cuda()):byte()
   else
      return gm.Image(src, "RGB", "DHW"):
         size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
         size(src:size(3), src:size(2), "Box"):
         toTensor("byte", "RGB", "DHW")
   end
   --]]
   local image = gm.Image(src, "RGB", "DHW")
   local width, height = src:size(3), src:size(2)
   -- Shrink to half size, then blow back up to the original size.
   image = image:size(width * 0.5, height * 0.5, "Box")
   image = image:size(width, height, "Box")
   return image:toTensor("byte", "RGB", "DHW")
end
|
2016-05-13 12:49:53 +12:00
|
|
|
|
|
|
|
return pairwise_transform_utils
|