1
0
Fork 0
mirror of synced 2024-06-02 19:14:30 +12:00
waifu2x/lib/pairwise_transform_utils.lua

163 lines
5.1 KiB
Lua

require 'image'
local iproc = require 'iproc'
local data_augmentation = require 'data_augmentation'
local pairwise_transform_utils = {}
--- Randomly halve an image's spatial size.
-- With probability `p` the image is passed to iproc.scale with half of
-- src:size(3) and half of src:size(2), using a filter drawn uniformly from
-- `filters`. Otherwise `src` is returned untouched.
-- @param src image tensor
-- @param p probability of applying the downscale
-- @param filters list of filter names handed to iproc.scale (assumed non-empty)
-- @return the downscaled tensor, or `src` itself
function pairwise_transform_utils.random_half(src, p, filters)
   if not (torch.uniform() < p) then
      return src
   end
   local half_w = src:size(3) * 0.5
   local half_h = src:size(2) * 0.5
   local chosen_filter = filters[torch.random(1, #filters)]
   return iproc.scale(src, half_w, half_h, chosen_filter)
end
--- Randomly crop a `max_size` x `max_size` window out of a large image.
-- Cropping only happens when src is strictly larger than `max_size` in both
-- dim 2 (rows) and dim 3 (columns); otherwise `src` is returned as-is.
-- @param src image tensor
-- @param max_size side length of the square crop; must be divisible by 4
-- @param mod optional; when given, the crop origin is snapped down to a
--   multiple of `mod`
-- @return the cropped tensor, or `src` when it is not large enough
function pairwise_transform_utils.crop_if_large(src, max_size, mod)
   local too_small = src:size(2) <= max_size or src:size(3) <= max_size
   if too_small then
      return src
   end
   assert(max_size % 4 == 0)
   local max_tries = 4
   local crop
   local attempt = 0
   repeat
      attempt = attempt + 1
      local top = torch.random(0, src:size(2) - max_size)
      local left = torch.random(0, src:size(3) - max_size)
      if mod then
         top = top - top % mod
         left = left - left % mod
      end
      crop = iproc.crop(src, left, top, left + max_size, top + max_size)
      -- meant to reject flat "simple background" crops, but a standard
      -- deviation is never negative, so for finite data this accepts the
      -- first crop and the retry loop never runs a second attempt
      local accepted = crop:float():std() >= 0
   until accepted or attempt >= max_tries
   return crop
end
-- Augmentation steps shared by both preprocess paths, applied in this fixed
-- order: flip, color noise, overlay, unsharp mask.
local function apply_shared_augmentations(img, options)
   img = data_augmentation.flip(img)
   img = data_augmentation.color_noise(img, options.random_color_noise_rate)
   img = data_augmentation.overlay(img, options.random_overlay_rate)
   img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
   return img
end

--- Prepare a source image before building a training pair from it.
-- When options.data.filters is exactly { "Box" }, the image is cropped with
-- a 2-pixel-aligned origin, augmented, and then trimmed by iproc.crop_mod4.
-- Otherwise it may first be halved by random_half, is cropped without
-- alignment, augmented, and finally passed through data_augmentation.shift_1px.
-- @param src source image tensor
-- @param crop_size training patch size; the crop side is at least 2x this
-- @param options table; reads data.filters, max_size, random_half_rate,
--   downsampling_filters and the random_*_rate augmentation settings
-- @return the preprocessed tensor
function pairwise_transform_utils.preprocess(src, crop_size, options)
   local filters = options.data.filters
   local box_only = false
   if filters and #filters == 1 then
      box_only = filters[1] == "Box"
   end

   local dest
   if box_only then
      -- crop origin aligned to a multiple of 2 (assert pos % 2 == 0)
      dest = pairwise_transform_utils.crop_if_large(src, math.max(crop_size * 2, options.max_size), 2)
      dest = apply_shared_augmentations(dest, options)
      return (iproc.crop_mod4(dest))
   end

   dest = pairwise_transform_utils.random_half(src, options.random_half_rate, options.downsampling_filters)
   dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
   dest = apply_shared_augmentations(dest, options)
   return (data_augmentation.shift_1px(dest))
end
--- Crop an aligned (x, y) training pair, optionally choosing a "hard" region.
-- With probability 1 - p a single random crop is taken. Otherwise `tries`
-- candidate origins are sampled, and the one where y differs most from
-- lowres_y (largest sum of squared differences) is kept.
-- @param x input tensor; x:size(2) * scale and x:size(3) * scale must equal
--   y's dims 2 and 3
-- @param y target tensor
-- @param lowres_y tensor cropped with y's coordinates (presumably a degraded
--   version of y at y's resolution -- confirm at call sites)
-- @param size crop side length in y pixels; must be divisible by `scale`
-- @param scale ratio between the y and x resolutions
-- @param p probability of using the error-driven (active) crop
-- @param tries number of candidate origins for the active crop (>= 1)
-- @return xc, yc the cropped x and y tensors
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
   -- Lua's assert is assert(condition, message). Passing the message first
   -- would test a (truthy) string and never fire.
   assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3), "x:size == y:size")
   assert(size % scale == 0, "crop_size % scale == 0")

   local x_h, x_w = x:size(2), x:size(3)

   -- Sample a crop origin on the x grid and convert it to y pixels.
   -- NOTE(review): the range 1 .. x_dim - (size + 1) mixes x units with
   -- `size` (y pixels), so for scale > 1 the right/bottom area and origin 0
   -- are never chosen; kept as-is so sampling stays identical.
   local function random_origin()
      local xi = torch.random(1, x_w - (size + 1)) * scale
      local yi = torch.random(1, x_h - (size + 1)) * scale
      return xi, yi
   end

   -- Crop y at (xi, yi) and x at the matching position on its own grid.
   local function crop_pair(xi, yi)
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xs, ys = xi / scale, yi / scale
      local xc = iproc.crop(x, xs, ys, xs + size / scale, ys + size / scale)
      return xc, yc
   end

   if p < torch.uniform() then
      return crop_pair(random_origin())
   end

   assert(type(tries) == "number" and tries >= 1, "active_cropping: tries must be >= 1")
   local best_se, best_xi, best_yi = 0.0, nil, nil
   local diff = torch.LongTensor(y:size(1), size, size)
   for _ = 1, tries do
      local xi, yi = random_origin()
      local y_view = iproc.crop_nocopy(y, xi, yi, xi + size, yi + size)
      local l_view = iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size)
      -- work in LongTensor so the difference and its square cannot wrap
      -- around for narrow input types such as bytes
      diff:copy(y_view:long()):csub(l_view:long())
      diff:cmul(diff)
      local se = diff:sum()
      if se >= best_se then
         best_xi, best_yi, best_se = xi, yi, se
      end
   end
   return crop_pair(best_xi, best_yi)
end
-- Return the eight flip/transpose variants of `src`, in this order:
-- src, vflip(src), hflip(src), hflip(vflip(src)), followed by the same four
-- applied to src with dims 2 and 3 swapped (made contiguous first).
local function flip_variants(src)
   local variants = {}
   local bases = { src, src:transpose(2, 3):contiguous() }
   for _, base in ipairs(bases) do
      local vflipped = image.vflip(base)
      variants[#variants + 1] = base
      variants[#variants + 1] = vflipped
      variants[#variants + 1] = image.hflip(base)
      variants[#variants + 1] = image.hflip(vflipped)
   end
   return variants
end

--- Build the 8-fold flip/transpose augmentation of an aligned tensor set.
-- All returned lists use the same variant order, so entry i of xs, ys, ls
-- (and ns, when x_noise is given) is the same transform of each input.
-- Each tensor is handled by a local helper, so no loop state leaks into
-- globals (the previous loop assigned an undeclared global `ni`).
-- @param x input tensor
-- @param y target tensor
-- @param lowres_y companion tensor, transformed the same way
-- @param x_noise optional tensor; when nil, ns is returned empty
-- @return xs, ys, ls, ns lists of tensors (8 entries each; ns may be empty)
function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
   local xs = flip_variants(x)
   local ys = flip_variants(y)
   local ls = flip_variants(lowres_y)
   local ns = {}
   if x_noise then
      ns = flip_variants(x_noise)
   end
   return xs, ys, ls, ns
end
return pairwise_transform_utils