2016-05-13 12:49:53 +12:00
|
|
|
require 'image'
|
|
|
|
local iproc = require 'iproc'
|
|
|
|
local data_augmentation = require 'data_augmentation'
|
|
|
|
-- Module table collecting the helper functions defined in this file;
-- it is returned at the end of the file.
local pairwise_transform_utils = {}
|
|
|
|
|
|
|
|
--- Randomly downscale an image to half size.
-- With probability `p` a filter is drawn uniformly from `filters` and the
-- image is resized with iproc.scale to (size(3) * 0.5, size(2) * 0.5)
-- (presumably width, height of a CxHxW tensor -- confirm with iproc).
-- Otherwise the very same `src` object is returned untouched.
-- @param src image tensor
-- @tparam number p probability of applying the downscale
-- @tparam table filters non-empty sequence of filter names for iproc.scale
-- @return the downscaled image, or `src`
function pairwise_transform_utils.random_half(src, p, filters)
   if not (torch.uniform() < p) then
      return src
   end
   local pick = torch.random(1, #filters)
   local half_w = src:size(3) * 0.5
   local half_h = src:size(2) * 0.5
   return iproc.scale(src, half_w, half_h, filters[pick])
end
|
2016-06-06 17:04:13 +12:00
|
|
|
--- Randomly crop a max_size x max_size patch out of an oversized image.
-- Cropping only happens when BOTH size(2) and size(3) exceed `max_size`;
-- otherwise `src` is returned unchanged.
--
-- Up to `tries` random windows are sampled. The first one whose float
-- standard deviation is >= `min_std` is returned. If none qualifies, the
-- last sampled window is returned.
--
-- NOTE: with the default `min_std` of 0 the std test always passes (a
-- standard deviation is never negative). The "ignore simple background"
-- filtering is therefore disabled by default and the first window is used.
-- Pass a positive `min_std` to actually reject flat crops.
-- @param src image tensor
-- @tparam number max_size side length of the crop; must be a multiple of 4
-- @tparam[opt] number mod if given, crop offsets are rounded down to a
--   multiple of `mod` (rounding down keeps the window inside the image)
-- @tparam[opt=0] number min_std minimum std for a crop to be accepted
-- @return the cropped patch, or `src`
function pairwise_transform_utils.crop_if_large(src, max_size, mod, min_std)
   local tries = 4
   if min_std == nil then
      min_std = 0
   end
   if src:size(2) > max_size and src:size(3) > max_size then
      assert(max_size % 4 == 0, "max_size must be a multiple of 4")
      local rect
      for _ = 1, tries do
         local yi = torch.random(0, src:size(2) - max_size)
         local xi = torch.random(0, src:size(3) - max_size)
         if mod then
            yi = yi - (yi % mod)
            xi = xi - (xi % mod)
         end
         rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
         -- ignore simple background (only effective when min_std > 0)
         if rect:float():std() >= min_std then
            break
         end
      end
      return rect
   else
      return src
   end
end
|
|
|
|
--- Apply the augmentation steps shared by both preprocess branches, in order:
-- flip, color noise, overlay, unsharp mask.
-- @param img image tensor
-- @tparam table options reads random_color_noise_rate, random_overlay_rate
--   and random_unsharp_mask_rate
-- @return the augmented image
local function apply_shared_augmentations(img, options)
   img = data_augmentation.flip(img)
   img = data_augmentation.color_noise(img, options.random_color_noise_rate)
   img = data_augmentation.overlay(img, options.random_overlay_rate)
   img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
   return img
end

--- Crop and augment a source image before the pairwise transform.
-- Two pipelines are used depending on `options.data.filters`:
--  * exactly {"Box"}: crop with even offsets, shared augmentations, then
--    iproc.crop_mod4 (no random_half, no shift_1px);
--  * anything else (including no filters / no options.data): random_half,
--    crop, shared augmentations, then data_augmentation.shift_1px.
-- @param src image tensor
-- @tparam number crop_size training patch size; crops are at least twice this
-- @tparam table options reads data.filters, max_size, random_half_rate,
--   downsampling_filters and the rates used by the shared augmentations
-- @return the preprocessed image
function pairwise_transform_utils.preprocess(src, crop_size, options)
   -- options.data may be absent; treat that the same as "no filters".
   local filters = options.data and options.data.filters
   local box_only = false
   if filters and #filters == 1 and filters[1] == "Box" then
      box_only = true
   end
   local max_size = math.max(crop_size * 2, options.max_size)
   local dest = src
   if box_only then
      -- keep crop offsets even (pos % 2 == 0)
      -- NOTE(review): presumably so a later 2x Box downscale stays
      -- pixel-aligned; confirm against the caller.
      dest = pairwise_transform_utils.crop_if_large(dest, max_size, 2)
      dest = apply_shared_augmentations(dest, options)
      dest = iproc.crop_mod4(dest)
   else
      dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
      dest = pairwise_transform_utils.crop_if_large(dest, max_size)
      dest = apply_shared_augmentations(dest, options)
      dest = data_augmentation.shift_1px(dest)
   end
   return dest
end
|
2016-05-30 22:15:54 +12:00
|
|
|
--- Crop an aligned (input, target) patch pair, optionally choosing the
-- hardest of several random candidates.
-- `y` must be exactly `scale` times larger than `x` in size(2) and size(3).
-- `size` is the patch side in `y` pixels and must be divisible by `scale`.
-- The matching `x` patch spans size / scale pixels.
--
-- When torch.uniform() > p (probability ~ 1 - p), a single random location
-- is used. Otherwise `tries` random locations are scored by the sum of
-- squared differences between `y` and `lowres_y` cropped at the same
-- coordinates, and the highest-scoring location (last one on ties) is kept.
-- @param x input tensor
-- @param y target tensor, `scale` times the spatial size of `x`
-- @param lowres_y tensor cropped with `y`'s coordinates; presumably a
--   degraded copy of `y` at the same size -- TODO confirm against callers
-- @tparam number size patch side in `y` pixels
-- @tparam number scale size ratio between `y` and `x`
-- @tparam number p probability of using the scored (active) selection
-- @tparam number tries number of candidates to score; must be >= 1
-- @return xc, yc cropped input and target patches
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
   -- Lua's assert takes the condition FIRST. The message-first form made
   -- both checks silently pass.
   assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3), "x:size == y:size")
   assert(size % scale == 0, "crop_size % scale == 0")

   -- Inclusive upper bounds of the random offsets, in x pixels.
   local max_xo = x:size(3) - (size + 1)
   local max_yo = x:size(2) - (size + 1)
   local lsize = size / scale

   if p < torch.uniform() then
      local xi = torch.random(1, max_xo) * scale
      local yi = torch.random(1, max_yo) * scale
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + lsize, yi / scale + lsize)
      return xc, yc
   end

   assert(tries and tries >= 1, "active_cropping: tries must be >= 1")
   local best_se = 0.0
   local best_xi, best_yi
   -- Reused buffer for the per-candidate squared error.
   local m = torch.LongTensor(y:size(1), size, size)
   for _ = 1, tries do
      local xi = torch.random(1, max_xo) * scale
      local yi = torch.random(1, max_yo) * scale
      local yc = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
      local lc = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
      -- Subtract in long precision (presumably to avoid unsigned wrap-around
      -- when the inputs are ByteTensors).
      m:copy(yc:long()):csub(lc:long())
      m:cmul(m)
      local se = m:sum()
      if se >= best_se then
         best_xi, best_yi, best_se = xi, yi, se
      end
   end
   local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
   local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + lsize, best_yi / scale + lsize)
   return xc, yc
end
|
2016-06-18 23:25:15 +12:00
|
|
|
--- Expand aligned tensors into their flipped/transposed variants (TTA-style).
-- For each of two bases, the original and the transpose of dims 2 and 3,
-- four variants are produced in this order:
--   base, vflip(base), hflip(base), hflip(vflip(base)).
-- Each returned list therefore has 8 entries, aligned by index across lists.
-- Assuming dims 2/3 are the spatial axes, these are the 8 symmetries of a
-- square.
-- @param x input tensor
-- @param y target tensor
-- @param lowres_y tensor transformed alongside y
-- @param[opt] x_noise optional tensor transformed alongside x
-- @return xs, ys, ls, ns lists of variants (ns is empty when x_noise is falsy)
function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
   local xs, ys, ls, ns = {}, {}, {}, {}

   -- Append t, vflip(t), hflip(t), hflip(vflip(t)) to list, in that order.
   -- A falsy t appends nothing; this is how an absent x_noise is skipped.
   local function push_flips(list, t)
      if not t then
         return
      end
      local tv = image.vflip(t)
      list[#list + 1] = t
      list[#list + 1] = tv
      list[#list + 1] = image.hflip(t)
      list[#list + 1] = image.hflip(tv)
   end

   for j = 1, 2 do
      -- TTA
      -- ni must be local as well; it previously leaked as a global.
      local xi, ni, yi, ri
      if j == 1 then
         xi, ni, yi, ri = x, x_noise, y, lowres_y
      else
         xi = x:transpose(2, 3):contiguous()
         if x_noise then
            ni = x_noise:transpose(2, 3):contiguous()
         end
         yi = y:transpose(2, 3):contiguous()
         ri = lowres_y:transpose(2, 3):contiguous()
      end
      push_flips(xs, xi)
      push_flips(ns, ni)
      push_flips(ys, yi)
      push_flips(ls, ri)
   end
   return xs, ys, ls, ns
end
|
2016-05-13 12:49:53 +12:00
|
|
|
|
|
|
|
-- Export the module table as the value of require.
return pairwise_transform_utils
|