1
0
Fork 0
mirror of synced 2024-06-01 18:49:33 +12:00
waifu2x/lib/pairwise_transform_utils.lua
2017-02-19 01:42:40 +09:00

289 lines
8.9 KiB
Lua

require 'cunn'
local iproc = require 'iproc'
local gm = {}
gm.Image = require 'graphicsmagick.Image'
local data_augmentation = require 'data_augmentation'
local pairwise_transform_utils = {}
-- Randomly halve an image.
-- One torch.uniform() draw is compared against `p`; when the draw is below
-- `p`, one entry of `filters` is picked uniformly at random and `src` is
-- handed to iproc.scale together with half of its dimension-3 extent, half
-- of its dimension-2 extent and that filter. Otherwise `src` is returned
-- unchanged (same object, no copy).
--   src     : image tensor indexed as (channel, dim 2, dim 3)
--   p       : probability of taking the downscale path
--   filters : non-empty array of filter identifiers understood by iproc.scale
function pairwise_transform_utils.random_half(src, p, filters)
   if not (torch.uniform() < p) then
      return src
   end
   local filter_name = filters[torch.random(1, #filters)]
   local half_w = src:size(3) * 0.5
   local half_h = src:size(2) * 0.5
   return iproc.scale(src, half_w, half_h, filter_name)
end
-- Crop a random max_size x max_size window out of `src` when both of its
-- spatial extents (dimensions 2 and 3) exceed `max_size`; otherwise return
-- `src` unchanged.
--   src      : image tensor indexed as (channel, dim 2, dim 3)
--   max_size : side length of the window; must be a multiple of 4
--   mod      : optional; when given, both offsets are rounded down to a
--              multiple of `mod`
-- Up to 4 windows are drawn. A window whose pixels are all identical
-- (standard deviation 0) is treated as plain background and redrawn; if
-- every attempt is flat, the last window drawn is returned anyway.
function pairwise_transform_utils.crop_if_large(src, max_size, mod)
   local tries = 4
   if src:size(2) > max_size and src:size(3) > max_size then
      assert(max_size % 4 == 0)
      local rect
      for _ = 1, tries do
         local yi = torch.random(0, src:size(2) - max_size)
         local xi = torch.random(0, src:size(3) - max_size)
         if mod then
            yi = yi - (yi % mod)
            xi = xi - (xi % mod)
         end
         rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
         -- ignore simple background
         -- A standard deviation is never negative, so the comparison has to
         -- be strict: `>= 0` accepted every window and the retry never ran.
         if rect:float():std() > 0 then
            break
         end
      end
      return rect
   else
      return src
   end
end
-- Paired version of crop_if_large: crop matching windows out of `x` and `y`
-- when both spatial extents of `y` (dimensions 2 and 3) exceed `max_size`;
-- otherwise return `x, y` unchanged.
--   x, y     : image tensors; `y` is the larger one, by a factor of `scale_y`
--   scale_y  : ratio between y-coordinates and x-coordinates
--   max_size : side length of the window in y-coordinates; must be a
--              multiple of 4
--   mod      : optional; when given, both y-offsets are rounded down to a
--              multiple of `mod`
-- The window is chosen in y-coordinates and divided by `scale_y` to obtain
-- the window of `x`. Up to 4 windows are drawn; a y-window whose pixels are
-- all identical (standard deviation 0) is treated as plain background and
-- redrawn. If every attempt is flat the last pair drawn is returned anyway.
function pairwise_transform_utils.crop_if_large_pair(x, y, scale_y, max_size, mod)
   local tries = 4
   if y:size(2) > max_size and y:size(3) > max_size then
      assert(max_size % 4 == 0)
      local rect_x, rect_y
      for _ = 1, tries do
         local yi = torch.random(0, y:size(2) - max_size)
         local xi = torch.random(0, y:size(3) - max_size)
         if mod then
            yi = yi - (yi % mod)
            xi = xi - (xi % mod)
         end
         rect_y = iproc.crop(y, xi, yi, xi + max_size, yi + max_size)
         -- Same window expressed in x-coordinates.
         local x_left = xi / scale_y
         local x_top = yi / scale_y
         rect_x = iproc.crop(x, x_left, x_top,
                             x_left + max_size / scale_y,
                             x_top + max_size / scale_y)
         -- ignore simple background
         -- A standard deviation is never negative, so the comparison has to
         -- be strict: `>= 0` accepted every window and the retry never ran.
         if rect_y:float():std() > 0 then
            break
         end
      end
      return rect_x, rect_y
   else
      return x, y
   end
end
-- Run the augmentation chain applied to a source image before pairs are
-- built from it.
--   src       : source image tensor
--   crop_size : training crop size; the large-image crop uses
--               max(crop_size * 2, options.max_size) as its window size
--   options   : settings table; reads options.data.filters, max_size and the
--               random_* rates/sizes passed to data_augmentation below
-- Two chains exist:
--   * "Box only" (options.data.filters is exactly {"Box"}): crop with
--     offsets aligned to 2, colour noise, overlay, unsharp mask, then
--     iproc.crop_mod4.
--   * otherwise: random half-scaling, unaligned crop, blur, colour noise,
--     overlay, unsharp mask, then shift_1px.
-- Returns the processed image.
function pairwise_transform_utils.preprocess(src, crop_size, options)
   local data_filters = options.data.filters
   local box_only = false
   if data_filters and #data_filters == 1 and data_filters[1] == "Box" then
      box_only = true
   end

   local img = src
   if not box_only then
      img = pairwise_transform_utils.random_half(img,
                                                 options.random_half_rate,
                                                 options.downsampling_filters)
      local window = math.max(crop_size * 2, options.max_size)
      img = pairwise_transform_utils.crop_if_large(img, window)
      img = data_augmentation.blur(img,
                                   options.random_blur_rate,
                                   options.random_blur_size,
                                   options.random_blur_sigma_min,
                                   options.random_blur_sigma_max)
      img = data_augmentation.color_noise(img, options.random_color_noise_rate)
      img = data_augmentation.overlay(img, options.random_overlay_rate)
      img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
      return data_augmentation.shift_1px(img)
   end

   -- Box-only chain: keep crop offsets even (pos % 2 == 0).
   local align = 2
   local window = math.max(crop_size * 2, options.max_size)
   img = pairwise_transform_utils.crop_if_large(img, window, align)
   img = data_augmentation.color_noise(img, options.random_color_noise_rate)
   img = data_augmentation.overlay(img, options.random_overlay_rate)
   img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
   return iproc.crop_mod4(img)
end
-- Augment a user-supplied (x, y) image pair.
--   x, y    : input / target image tensors
--   scale_y : ratio between y-coordinates and x-coordinates; also used as
--             the offset alignment for the large-image crop
--   size    : crop size; bounds the random scale from below so the shorter
--             side of x does not shrink under it
--   options : settings table; reads max_size, the random_pairwise_* fields
--             and pairwise_y_binary
-- Steps, in order: paired large-image crop, paired rotate, paired scale,
-- paired negate, negate of x only, iproc.crop_mod4 on each image, and an
-- optional binarisation of y. Returns x, y.
function pairwise_transform_utils.preprocess_user(x, y, scale_y, size, options)
   x, y = pairwise_transform_utils.crop_if_large_pair(x, y, scale_y,
                                                      options.max_size, scale_y)
   x, y = data_augmentation.pairwise_rotate(x, y,
                                            options.random_pairwise_rotate_rate,
                                            options.random_pairwise_rotate_min,
                                            options.random_pairwise_rotate_max)

   -- Smallest scale factor that still leaves room for a `size` crop in x.
   local shortest_side = math.min(x:size(2), x:size(3))
   local scale_floor = size / (1 + shortest_side)
   local scale_min = math.max(options.random_pairwise_scale_min, scale_floor)
   local scale_max = math.max(scale_min, options.random_pairwise_scale_max)
   x, y = data_augmentation.pairwise_scale(x, y,
                                           options.random_pairwise_scale_rate,
                                           scale_min,
                                           scale_max)

   x, y = data_augmentation.pairwise_negate(x, y, options.random_pairwise_negate_rate)
   x, y = data_augmentation.pairwise_negate_x(x, y, options.random_pairwise_negate_x_rate)
   x = iproc.crop_mod4(x)
   y = iproc.crop_mod4(y)

   if options.pairwise_y_binary then
      -- Threshold at 128: values below become 0, everything still above 0
      -- afterwards becomes 255. The second mask must be built after the
      -- first assignment.
      local below = torch.lt(y, 128)
      y[below] = 0
      local remaining = torch.gt(y, 0)
      y[remaining] = 255
   end
   return x, y
end
-- Cut one matching training crop out of an (x, y) pair.
--   x        : input image tensor
--   y        : target image tensor, `scale` times the size of x
--   lowres_y : low-resolution rendition of y with the same size as y; only
--              read on the active path
--   size     : crop side length in y-coordinates; must be divisible by scale
--   scale    : ratio between y-coordinates and x-coordinates
--   p        : probability of taking the active path
--   tries    : number of candidate windows examined on the active path
-- With probability 1 - p a single random window is used. Otherwise `tries`
-- random windows are drawn and the one with the largest summed squared
-- difference between y and lowres_y is kept.
-- Returns xc, yc: the crop of x (side size / scale) and the crop of y
-- (side size).
-- Raises an error when the size preconditions do not hold.
function pairwise_transform_utils.active_cropping(x, y, lowres_y, size, scale, p, tries)
   -- assert takes the condition first and the message second; with the
   -- arguments the other way round the (always truthy) string was tested
   -- and these checks could never fail.
   assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
          "x:size == y:size")
   assert(size % scale == 0, "crop_size % scale == 0")

   if p < torch.uniform() then
      -- Plain random crop. Offsets are drawn in x-coordinates and scaled up.
      local xi = 0
      local yi = 0
      if x:size(3) > size + 1 then
         xi = torch.random(0, x:size(3) - (size + 1)) * scale
      end
      if x:size(2) > size + 1 then
         yi = torch.random(0, x:size(2) - (size + 1)) * scale
      end
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      local xc = iproc.crop(x, xi / scale, yi / scale, xi / scale + size / scale, yi / scale + size / scale)
      return xc, yc
   end

   -- Active crop: collect `tries` candidate windows of y and lowres_y.
   local y_crops = torch.LongTensor(tries, y:size(1), size, size)
   local low_crops = torch.LongTensor(tries, lowres_y:size(1), size, size)
   -- Row 1 holds the dimension-3 offsets, row 2 the dimension-2 offsets.
   local offsets = torch.LongTensor(2, tries)
   -- random(1, n) raises when n < 1, so fall back to offset 0 for images
   -- that are not larger than size + 1, as the plain branch above does.
   if x:size(3) > size + 1 then
      offsets[1]:random(1, x:size(3) - (size + 1)):mul(scale)
   else
      offsets[1]:zero()
   end
   if x:size(2) > size + 1 then
      offsets[2]:random(1, x:size(2) - (size + 1)):mul(scale)
   else
      offsets[2]:zero()
   end
   for i = 1, tries do
      local xi = offsets[1][i]
      local yi = offsets[2][i]
      y_crops[i]:copy(iproc.crop_nocopy(y, xi, yi, xi + size, yi + size))
      low_crops[i]:copy(iproc.crop_nocopy(lowres_y, xi, yi, xi + size, yi + size))
   end

   -- Squared error per element, summed per candidate; keep the largest.
   y_crops:csub(low_crops)
   y_crops:cmul(y_crops)
   local _, best = y_crops:reshape(y_crops:size(1), y_crops:nElement() / y_crops:size(1)):transpose(1, 2):sum(1):topk(1, true)
   local best_xi = offsets[1][best[1][1]]
   local best_yi = offsets[2][best[1][1]]
   local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
   local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale)
   return xc, yc
end
-- Build the flipped/transposed variants of a set of aligned images.
--   x, y     : required image tensors
--   lowres_y : optional image tensor
--   x_noise  : optional image tensor
-- Two passes are made: the images as given, then the images with
-- dimensions 2 and 3 swapped. Each pass appends four entries per image, in
-- this order: the image, its vflip, its hflip, the hflip of its vflip.
-- Returns xs, ys, ls, ns: arrays of 8 variants for x, y, lowres_y and
-- x_noise. The arrays for omitted optional images are empty.
function pairwise_transform_utils.flip_augmentation(x, y, lowres_y, x_noise)
   -- Append img and its three flipped variants to list.
   local function append_flips(list, img)
      local flipped_v = iproc.vflip(img)
      table.insert(list, img)
      table.insert(list, flipped_v)
      table.insert(list, iproc.hflip(img))
      table.insert(list, iproc.hflip(flipped_v))
   end

   local xs, ns, ys, ls = {}, {}, {}, {}
   for pass = 1, 2 do
      -- TTA
      local cur_x, cur_noise, cur_y, cur_low = x, x_noise, y, lowres_y
      if pass == 2 then
         cur_x = x:transpose(2, 3):contiguous()
         cur_noise = nil
         if x_noise then
            cur_noise = x_noise:transpose(2, 3):contiguous()
         end
         cur_y = y:transpose(2, 3):contiguous()
         cur_low = nil
         if lowres_y then
            cur_low = lowres_y:transpose(2, 3):contiguous()
         end
      end

      append_flips(xs, cur_x)
      if cur_noise then
         append_flips(ns, cur_noise)
      end
      append_flips(ys, cur_y)
      if cur_low then
         append_flips(ls, cur_low)
      end
   end
   return xs, ys, ls, ns
end
-- Build the CUDA network for the GPU low-resolution path:
-- nn.SpatialAveragePooling(2, 2, 2, 2) followed by
-- nn.SpatialUpSamplingNearest(2), wrapped in an nn.Sequential and converted
-- with :cuda(). Returns the converted container.
local function lowres_model()
   local pool = nn.SpatialAveragePooling(2, 2, 2, 2)
   local upsample = nn.SpatialUpSamplingNearest(2)
   local model = nn.Sequential()
   model:add(pool)
   model:add(upsample)
   return model:cuda()
end
-- Module-level cache slots for the GPU low-resolution path. In the visible
-- file they are referenced only by the commented-out block inside
-- pairwise_transform_utils.low_resolution, so both currently stay nil.
local g_lowres_model = nil
local g_lowres_gpu = nil
-- Produce a low-resolution rendition of `src` at its original size.
-- The image is resized to half of its dimension-3 / dimension-2 extents
-- with the "Box" filter, resized back to the original extents with the
-- same filter, and returned as a byte tensor in "RGB" / "DHW" layout.
--   src : image tensor handed to gm.Image as "RGB", "DHW"
function pairwise_transform_utils.low_resolution(src)
   --[[
   -- Disabled: benchmark the GPU model against GraphicsMagick once and use
   -- whichever is faster afterwards.
   -- I am not sure that the following process is thread-safe
   g_lowres_model = g_lowres_model or lowres_model()
   if g_lowres_gpu == nil then
      --benchmark
      local gpu_time = sys.clock()
      for i = 1, 10 do
         g_lowres_model:forward(src:cuda()):byte()
      end
      gpu_time = sys.clock() - gpu_time
      local cpu_time = sys.clock()
      for i = 1, 10 do
         gm.Image(src, "RGB", "DHW"):
            size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
            size(src:size(3), src:size(2), "Box"):
            toTensor("byte", "RGB", "DHW")
      end
      cpu_time = sys.clock() - cpu_time
      --print(gpu_time, cpu_time)
      if gpu_time < cpu_time then
         g_lowres_gpu = true
      else
         g_lowres_gpu = false
      end
   end
   if g_lowres_gpu then
      return g_lowres_model:forward(src:cuda()):byte()
   else
      return gm.Image(src, "RGB", "DHW"):
         size(src:size(3) * 0.5, src:size(2) * 0.5, "Box"):
         size(src:size(3), src:size(2), "Box"):
         toTensor("byte", "RGB", "DHW")
   end
   --]]
   local image = gm.Image(src, "RGB", "DHW")
   local full_w = src:size(3)
   local full_h = src:size(2)
   local halved = image:size(full_w * 0.5, full_h * 0.5, "Box")
   local restored = halved:size(full_w, full_h, "Box")
   return restored:toTensor("byte", "RGB", "DHW")
end
-- Hand the module table back to require().
return pairwise_transform_utils