2016-05-13 12:49:53 +12:00
|
|
|
require 'image'
|
|
|
|
local iproc = require 'iproc'
|
|
|
|
local data_augmentation = require 'data_augmentation'
|
|
|
|
-- Module table: every helper below is attached to it, and it is returned at
-- the end of the file.
local pairwise_transform_utils = {}
|
|
|
|
|
|
|
|
--- Randomly shrink an image to half of its size.
--
-- A uniform random number is drawn on every call; when it is below `p` the
-- image is rescaled to half of `src:size(3)` by half of `src:size(2)` with a
-- filter picked at random from `filters`. Otherwise `src` is returned as is.
--
-- @param src     image tensor (presumably C x H x W, since dimensions 2 and 3
--                are used as the spatial extents -- confirm against callers)
-- @param p       threshold compared against the uniform draw
-- @param filters non-empty sequence of filter identifiers passed to `iproc.scale`
-- @return the rescaled image, or `src` itself when no rescale happens
function pairwise_transform_utils.random_half(src, p, filters)
   local roll = torch.uniform()
   if not (roll < p) then
      return src
   end
   -- The filter index is only drawn when a rescale actually happens.
   local choice = torch.random(1, #filters)
   local chosen_filter = filters[choice]
   local half_w = src:size(3) * 0.5
   local half_h = src:size(2) * 0.5
   return iproc.scale(src, half_w, half_h, chosen_filter)
end
|
2016-06-06 17:04:13 +12:00
|
|
|
--- Cut a random `max_size` x `max_size` region out of an oversized image.
--
-- The crop only happens when BOTH `src:size(2)` and `src:size(3)` exceed
-- `max_size`; otherwise `src` is returned untouched. Up to 4 random positions
-- are tried, and the first crop whose standard deviation is non-zero (i.e.
-- not a perfectly flat region) is returned. If every attempt is flat, the
-- last candidate is returned anyway.
--
-- @param src      image tensor (dimensions 2 and 3 are treated as the spatial
--                 extents; presumably C x H x W -- confirm against callers)
-- @param max_size edge length of the crop; must be a multiple of 4
-- @param mod      optional; when given, the crop origin is rounded down to a
--                 multiple of `mod` on both axes
-- @return the cropped tensor, or `src` when it is not larger than `max_size`
function pairwise_transform_utils.crop_if_large(src, max_size, mod)
   local tries = 4
   if src:size(2) > max_size and src:size(3) > max_size then
      assert(max_size % 4 == 0, "max_size must be a multiple of 4")
      local rect
      for _ = 1, tries do
         local yi = torch.random(0, src:size(2) - max_size)
         local xi = torch.random(0, src:size(3) - max_size)
         if mod then
            -- Snap the origin down so it stays aligned to `mod`.
            yi = yi - (yi % mod)
            xi = xi - (xi % mod)
         end
         rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
         -- Ignore simple background: a flat crop has a standard deviation of
         -- exactly zero, so only a strictly positive value ends the search.
         -- (The previous `>= 0` test was always true and never retried.)
         if rect:float():std() > 0 then
            break
         end
      end
      return rect
   else
      return src
   end
end
|
|
|
|
--- Run the random augmentation pipeline on a source image.
--
-- Two variants exist, selected by `options.data.filters`:
--   * "Box only" (the filter list is exactly `{"Box"}`): the image is cropped
--     with the crop origin aligned to multiples of 2, then flipped,
--     colour-noised, overlaid and unsharp-masked.
--   * Otherwise: the image is first passed through `random_half`, cropped
--     without alignment, run through the same four augmentations, and
--     finally passed to `data_augmentation.shift_1px`.
--
-- The exact effect of each `data_augmentation.*` step is defined in that
-- module and is not visible here.
--
-- @param src       source image tensor
-- @param crop_size crop size; the large-image crop uses
--                  `max(crop_size * 2, options.max_size)` as its edge length
-- @param options   table providing `data.filters`, `max_size`,
--                  `random_half_rate`, `downsampling_filters`,
--                  `random_color_noise_rate`, `random_overlay_rate` and
--                  `random_unsharp_mask_rate`
-- @return the augmented image
function pairwise_transform_utils.preprocess(src, crop_size, options)
   local dest = src

   local filters = options.data.filters
   local box_only = false
   if filters and #filters == 1 and filters[1] == "Box" then
      box_only = true
   end

   -- Only the general pipeline may halve the image before cropping.
   if not box_only then
      dest = pairwise_transform_utils.random_half(dest,
                                                  options.random_half_rate,
                                                  options.downsampling_filters)
   end

   -- Box-only data needs even crop offsets (pos % 2 == 0); otherwise no
   -- alignment is requested.
   local mod = nil
   if box_only then
      mod = 2
   end
   local limit = math.max(crop_size * 2, options.max_size)
   dest = pairwise_transform_utils.crop_if_large(dest, limit, mod)

   -- Augmentations shared by both variants, applied in a fixed order.
   dest = data_augmentation.flip(dest)
   dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
   dest = data_augmentation.overlay(dest, options.random_overlay_rate)
   dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)

   if not box_only then
      dest = data_augmentation.shift_1px(dest)
   end
   return dest
end
|
2016-05-30 22:15:54 +12:00
|
|
|
--- Crop a matching pair of patches from a low-resolution input `x` and its
-- target `y`, optionally favouring regions that are hard to reconstruct.
--
-- A uniform random number is drawn; when it exceeds `p`, one random position
-- is used. Otherwise `tries` random positions are sampled and the one with
-- the largest squared error between the `y` crop and the `lowres_y` crop is
-- kept (ties go to the later sample).
--
-- Positions are chosen in `y` coordinates as multiples of `scale`, so the
-- corresponding `x` coordinates (position / scale) are whole numbers.
--
-- @param x        low-resolution tensor; `x:size(2) * scale` and
--                 `x:size(3) * scale` must equal the matching sizes of `y`
-- @param y        target tensor
-- @param lowres_y tensor cropped with the same coordinates as `y`
--                 (presumably a degraded version of `y` at the same
--                 resolution -- confirm against callers)
-- @param size     crop edge length in `y` coordinates; must be a multiple of
--                 `scale`
-- @param scale    size ratio between `y` and `x`
-- @param p        threshold compared against the uniform draw
-- @param tries    number of candidate positions evaluated when searching
-- @return xc, yc  the crop of `x` and the crop of `y`
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
   -- assert(value, message): the condition must come first. The earlier
   -- swapped form tested the (always truthy) message string and never fired.
   assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
          "x:size * scale must equal y:size")
   assert(size % scale == 0, "crop_size % scale must be 0")
   local r = torch.uniform()

   -- Draw a random crop origin in `y` coordinates (x offset first, then y).
   local function random_offset()
      local xi = torch.random(1, x:size(3) - (size + 1)) * scale
      local yi = torch.random(1, x:size(2) - (size + 1)) * scale
      return xi, yi
   end

   -- Crop `y` at (xi, yi) and `x` at the same position divided by `scale`.
   local function crop_pair(xi, yi)
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x,
                            xi / scale, yi / scale,
                            xi / scale + size / scale, yi / scale + size / scale)
      return xc, yc
   end

   if p < r then
      return crop_pair(random_offset())
   end

   local best_se = 0.0
   local best_xi, best_yi
   -- Scratch buffer reused for every candidate's difference image.
   local m = torch.FloatTensor(y:size(1), size, size)
   for _ = 1, tries do
      local xi, yi = random_offset()
      local y_patch = iproc.crop(y, xi, yi, xi + size, yi + size)
      local lowres_patch = iproc.crop(lowres_y, xi, yi, xi + size, yi + size)
      local y_float = iproc.byte2float(y_patch)
      local lowres_float = iproc.byte2float(lowres_patch)
      -- Sum of squared differences between the target and low-res crops.
      local se = m:copy(y_float):add(-1.0, lowres_float):pow(2):sum()
      if se >= best_se then
         best_xi = xi
         best_yi = yi
         best_se = se
      end
   end
   return crop_pair(best_xi, best_yi)
end
|
|
|
|
|
|
|
|
-- Hand the module table back to `require`.
return pairwise_transform_utils
|