1
0
Fork 0
mirror of synced 2024-05-18 11:52:17 +12:00

Add new models

upconv_7 is 2.3x faster than previous model
This commit is contained in:
nagadomi 2016-05-13 09:49:53 +09:00
parent e62305377f
commit 51ae485cd1
8 changed files with 608 additions and 320 deletions

View file

@ -1,255 +1,9 @@
require 'image'
local gm = require 'graphicsmagick'
local iproc = require 'iproc'
local data_augmentation = require 'data_augmentation'
require 'pl'
local pairwise_transform = {}
-- With probability p, downscale src to half resolution using a filter
-- drawn at random from `filters`; otherwise return src untouched.
local function random_half(src, p, filters)
   if torch.uniform() >= p then
      return src
   end
   local chosen = filters[torch.random(1, #filters)]
   return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, chosen)
end
-- If both dimensions of src exceed max_size, return a random
-- max_size x max_size crop; otherwise return src unchanged.
-- Retries up to `tries` times to avoid flat, contentless crops.
local function crop_if_large(src, max_size)
   local tries = 4
   if src:size(2) > max_size and src:size(3) > max_size then
      local rect
      for i = 1, tries do
         local yi = torch.random(0, src:size(2) - max_size)
         local xi = torch.random(0, src:size(3) - max_size)
         rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
         -- ignore simple background: a zero-variance crop is a flat
         -- region, so retry. (BUG FIX: the original tested `std() >= 0`,
         -- which is always true and accepted the first crop every time.)
         if rect:float():std() > 0 then
            break
         end
      end
      return rect
   else
      return src
   end
end
-- Run the full randomized augmentation pipeline over src and return the
-- transformed image. The order of stages is significant.
local function preprocess(src, crop_size, options)
   local out = random_half(src, options.random_half_rate, options.downsampling_filters)
   out = crop_if_large(out, math.max(crop_size * 2, options.max_size))
   out = data_augmentation.flip(out)
   out = data_augmentation.color_noise(out, options.random_color_noise_rate)
   out = data_augmentation.overlay(out, options.random_overlay_rate)
   out = data_augmentation.unsharp_mask(out, options.random_unsharp_mask_rate)
   out = data_augmentation.shift_1px(out)
   return out
end
-- Crop a matched (x, y) patch pair of size x size.
-- With probability (1 - p) the position is uniform random; otherwise
-- `tries` candidate positions are scored and the one whose x-crop differs
-- most from a box-blurred copy of x (i.e. the most detail-rich region)
-- is selected.
local function active_cropping(x, y, size, p, tries)
   -- BUG FIX: the original called assert("x:size == y:size", cond) — the
   -- message string is truthy, so the check could never fail. Arguments
   -- are swapped here so the size check is actually enforced.
   assert(x:size(2) == y:size(2) and x:size(3) == y:size(3), "x:size == y:size")
   local r = torch.uniform()
   local t = "float"
   if x:type() == "torch.ByteTensor" then
      t = "byte"
   end
   if p < r then
      -- passive path: a single uniform-random crop
      local xi = torch.random(0, y:size(3) - (size + 1))
      local yi = torch.random(0, y:size(2) - (size + 1))
      local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
      local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
      return xc, yc
   else
      -- active path: score candidates against a box down/up-scaled copy
      local lowres = gm.Image(x, "RGB", "DHW"):
         size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
         size(x:size(3), x:size(2), "Box"):
         toTensor(t, "RGB", "DHW")
      local best_se = 0.0
      local best_xc, best_yc
      local m = torch.FloatTensor(x:size(1), size, size)
      for i = 1, tries do
         local xi = torch.random(0, y:size(3) - (size + 1))
         local yi = torch.random(0, y:size(2) - (size + 1))
         local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
         local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
         local xcf = iproc.byte2float(xc)
         local lcf = iproc.byte2float(lc)
         -- squared error vs. the blurred copy; larger = more detail
         local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
         if se >= best_se then
            best_xc = xcf
            best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size))
            best_se = se
         end
      end
      -- NOTE(review): this branch returns float tensors while the passive
      -- branch returns crops in the inputs' original type; callers run
      -- iproc.byte2float on the result, so this is presumably harmless —
      -- confirm byte2float is a no-op on floats.
      return best_xc, best_yc
   end
end
-- Build n (x, y) training pairs for the super-resolution task.
-- y: augmented ground truth; x: y downscaled by 1/scale with a randomly
-- chosen filter, then upscaled back to y's size. Each batch entry is
-- {x_patch, y_patch} where the y patch is shrunk by `offset` on every
-- side (the model's receptive-field border).
function pairwise_transform.scale(src, scale, size, offset, n, options)
local filters = options.downsampling_filters
-- resampling artifacts near the image border are unrepresentative; trim 8px
local unstable_region_offset = 8
local downsampling_filter = filters[torch.random(1, #filters)]
local y = preprocess(src, size, options)
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
local down_scale = 1.0 / scale
local x
if options.gamma_correction then
-- downscale in gamma-2.2-corrected light, then upscale back
x = iproc.scale(iproc.scale_with_gamma22(y, y:size(3) * down_scale,
y:size(2) * down_scale, downsampling_filter),
y:size(3), y:size(2), options.upsampling_filter)
else
x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
y:size(2) * down_scale, downsampling_filter),
y:size(3), y:size(2), options.upsampling_filter)
end
-- trim the unreliable borders from both images identically
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y,
size,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
-- Y-channel-only training: keep luminance, drop chroma
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
return batch
end
-- Build n (x, y) training pairs for JPEG-noise reduction.
-- y: augmented ground truth; x: y JPEG-compressed once per entry of the
-- `quality` list (multiple passes simulate recompressed images).
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
local unstable_region_offset = 8
local y = preprocess(src, size, options)
local x = y
-- NOTE(review): the two tablex.update lines below look like module
-- composition code spliced in from another file (likely a diff-scrape
-- artifact) — re-requiring submodules inside this function is suspicious
-- and would recurse through require; verify against the real source.
pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_scale'))
pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_jpeg'))
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg"):depth(8)
if torch.uniform() < options.jpeg_chroma_subsampling_rate then
-- YUV 420
x:samplingFactors({2.0, 1.0, 1.0})
else
-- YUV 444
x:samplingFactors({1.0, 1.0, 1.0})
end
-- round-trip through an in-memory JPEG blob at this quality
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
-- trim unreliable borders from both images identically
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y, size,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
-- Y-channel-only training: keep luminance, drop chroma
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
if torch.uniform() < options.nr_rate then
-- reducing noise: noisy input, clean target
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
else
-- retain useful details: feed the clean patch as input too
table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
end
return batch
end
-- Dispatch to jpeg_ with a compression-quality schedule chosen by style
-- and noise level. art/1: one light pass; art/2,3: randomly 1-3 passes of
-- decreasing quality (levels 2 and 3 differ only via the -nr_rate option
-- upstream); photo: one mid-range pass.
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
if style == "art" then
if level == 1 then
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
size, offset, n, options)
elseif level == 2 or level == 3 then
-- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
local r = torch.uniform()
if r > 0.6 then
-- single heavier compression pass
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
size, offset, n, options)
elseif r > 0.3 then
-- two passes, second slightly worse than the first
local quality1 = torch.random(37, 70)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_(src, {quality1, quality2},
size, offset, n, options)
else
-- three passes, progressively worse
local quality1 = torch.random(52, 70)
local quality2 = quality1 - torch.random(5, 15)
local quality3 = quality1 - torch.random(15, 25)
return pairwise_transform.jpeg_(src,
{quality1, quality2, quality3},
size, offset, n, options)
end
else
error("unknown noise level: " .. level)
end
elseif style == "photo" then
-- level adjusting by -nr_rate
return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
size, offset, n,
options)
else
error("unknown style: " .. style)
end
end
print(pairwise_transform)
-- Visual smoke test: generate and display (y, x) JPEG-noise pairs.
-- BUG FIXES: (1) `local src = image.lena()` shadowed the parameter, so
-- any image passed in was silently ignored; (2) preprocess/random_half
-- indexes options.downsampling_filters, which was missing here and would
-- error roughly half the time (random_half_rate = 0.5).
function pairwise_transform.test_jpeg(src)
   torch.setdefaulttensortype("torch.FloatTensor")
   local options = {random_color_noise_rate = 0.5,
                    random_half_rate = 0.5,
                    random_overlay_rate = 0.5,
                    random_unsharp_mask_rate = 0.5,
                    jpeg_chroma_subsampling_rate = 0.5,
                    nr_rate = 1.0,
                    active_cropping_rate = 0.5,
                    active_cropping_tries = 10,
                    max_size = 256,
                    downsampling_filters = {"Box"},
                    rgb = true
   }
   local image = require 'image'
   src = src or image.lena()
   for i = 1, 9 do
      local xy = pairwise_transform.jpeg(src,
                                         "art",
                                         torch.random(1, 2),
                                         128, 7, 1, options)
      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
   end
end
-- Visual smoke test: generate and display (y, x) scale pairs.
-- BUG FIXES: (1) `local src = image.lena()` shadowed the parameter, so
-- any image passed in was silently ignored; (2) pairwise_transform.scale
-- indexes options.downsampling_filters, which was missing here and would
-- raise on the very first call.
function pairwise_transform.test_scale(src)
   torch.setdefaulttensortype("torch.FloatTensor")
   local options = {random_color_noise_rate = 0.5,
                    random_half_rate = 0.5,
                    random_overlay_rate = 0.5,
                    random_unsharp_mask_rate = 0.5,
                    active_cropping_rate = 0.5,
                    active_cropping_tries = 10,
                    max_size = 256,
                    downsampling_filters = {"Box"},
                    rgb = true
   }
   local image = require 'image'
   src = src or image.lena()
   for i = 1, 10 do
      local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
   end
end
return pairwise_transform

View file

@ -0,0 +1,117 @@
local pairwise_utils = require 'pairwise_transform_utils'
local gm = require 'graphicsmagick'
local iproc = require 'iproc'
local pairwise_transform = {}
-- Build n (x, y) training pairs for JPEG-noise reduction.
-- y: augmented ground truth; x: y JPEG-compressed once per entry of the
-- `quality` list (multiple passes simulate recompressed images).
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
-- compression/resampling borders are unreliable; trim 8px
local unstable_region_offset = 8
local y = pairwise_utils.preprocess(src, size, options)
local x = y
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg"):depth(8)
if torch.uniform() < options.jpeg_chroma_subsampling_rate then
-- YUV 420
x:samplingFactors({2.0, 1.0, 1.0})
else
-- YUV 444
x:samplingFactors({1.0, 1.0, 1.0})
end
-- round-trip through an in-memory JPEG blob at this quality
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
local batch = {}
for i = 1, n do
-- scale = 1: x and y have the same resolution in the noise task
local xc, yc = pairwise_utils.active_cropping(x, y, size, 1,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
-- NOTE(review): `image` is a global here — this file never requires
-- 'image' directly; presumably a transitive require sets it. Confirm.
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
if torch.uniform() < options.nr_rate then
-- reducing noise: noisy input, clean target
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
else
-- retain useful details: feed the clean patch as input too
table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
end
return batch
end
-- Entry point for JPEG-noise pair generation: choose a compression
-- quality schedule from (style, level), then delegate to jpeg_.
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
   if style == "photo" then
      -- level adjusting by -nr_rate
      return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
                                      size, offset, n, options)
   end
   if style ~= "art" then
      error("unknown style: " .. style)
   end
   if level == 1 then
      -- one light compression pass
      return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
                                      size, offset, n, options)
   end
   if level ~= 2 and level ~= 3 then
      error("unknown noise level: " .. level)
   end
   -- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
   local dice = torch.uniform()
   if dice > 0.6 then
      -- single heavy pass
      return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
                                      size, offset, n, options)
   elseif dice > 0.3 then
      -- two passes, the second slightly worse
      local q1 = torch.random(37, 70)
      local q2 = q1 - torch.random(5, 10)
      return pairwise_transform.jpeg_(src, {q1, q2},
                                      size, offset, n, options)
   else
      -- three passes, progressively worse
      local q1 = torch.random(52, 70)
      local q2 = q1 - torch.random(5, 15)
      local q3 = q1 - torch.random(15, 25)
      return pairwise_transform.jpeg_(src, {q1, q2, q3},
                                      size, offset, n, options)
   end
end
-- Visual smoke test: generate and display (y, x) JPEG-noise pairs.
-- BUG FIXES: (1) `local src = image.lena()` shadowed the parameter, so
-- any image passed in was silently ignored; (2) preprocess/random_half
-- indexes options.downsampling_filters, which was missing here and would
-- error roughly half the time (random_half_rate = 0.5).
function pairwise_transform.test_jpeg(src)
   torch.setdefaulttensortype("torch.FloatTensor")
   local options = {random_color_noise_rate = 0.5,
                    random_half_rate = 0.5,
                    random_overlay_rate = 0.5,
                    random_unsharp_mask_rate = 0.5,
                    jpeg_chroma_subsampling_rate = 0.5,
                    nr_rate = 1.0,
                    active_cropping_rate = 0.5,
                    active_cropping_tries = 10,
                    max_size = 256,
                    downsampling_filters = {"Box"},
                    rgb = true
   }
   local image = require 'image'
   src = src or image.lena()
   for i = 1, 9 do
      local xy = pairwise_transform.jpeg(src,
                                         "art",
                                         torch.random(1, 2),
                                         128, 7, 1, options)
      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
   end
end
return pairwise_transform

View file

@ -0,0 +1,86 @@
local pairwise_utils = require 'pairwise_transform_utils'
local iproc = require 'iproc'
local pairwise_transform = {}
-- Build n (x, y) training pairs for super-resolution.
-- y: augmented ground truth. x: y downscaled by 1/scale; when
-- options.x_upsampling is true x is also upscaled back to y's size
-- (classic SRCNN-style input), otherwise x stays small and the model
-- upsamples internally (upconv_* architectures).
function pairwise_transform.scale(src, scale, size, offset, n, options)
local filters = options.downsampling_filters
-- resampling near the border is unreliable; trim 8px (upsampled path only)
local unstable_region_offset = 8
local downsampling_filter = filters[torch.random(1, #filters)]
local y = pairwise_utils.preprocess(src, size, options)
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
local down_scale = 1.0 / scale
local x
if options.gamma_correction then
-- downscale in gamma-2.2-corrected light
local small = iproc.scale_with_gamma22(y, y:size(3) * down_scale,
y:size(2) * down_scale, downsampling_filter)
if options.x_upsampling then
x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter)
else
x = small
end
else
local small = iproc.scale(y, y:size(3) * down_scale,
y:size(2) * down_scale, downsampling_filter)
if options.x_upsampling then
x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter)
else
x = small
end
end
if options.x_upsampling then
-- x and y are the same size; trim borders from both identically
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
else
-- NOTE(review): no border trim on this path — presumably acceptable
-- because the offset crop below removes edges; confirm.
assert(x:size(1) == y:size(1) and x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
end
-- when x was not upsampled, active_cropping must map coordinates by scale
local scale_inner = scale
if options.x_upsampling then
scale_inner = 1
end
local batch = {}
for i = 1, n do
local xc, yc = pairwise_utils.active_cropping(x, y,
size,
scale_inner,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
-- Y-channel-only training: keep luminance, drop chroma.
-- NOTE(review): `image` is a global here (not required in this file);
-- presumably set by a transitive require — confirm.
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
return batch
end
-- Visual smoke test for the scale pair generator.
-- BUG FIXES: (1) `local src = image.lena()` shadowed the parameter, so
-- any image passed in was silently ignored; (2) downsampling_filters is
-- indexed as an array (filters[torch.random(1, #filters)]) — the bare
-- string "Box" made the selected filter nil; it must be a list.
function pairwise_transform.test_scale(src)
   torch.setdefaulttensortype("torch.FloatTensor")
   local options = {random_color_noise_rate = 0.5,
                    random_half_rate = 0.5,
                    random_overlay_rate = 0.5,
                    random_unsharp_mask_rate = 0.5,
                    active_cropping_rate = 0.5,
                    active_cropping_tries = 10,
                    max_size = 256,
                    x_upsampling = false,
                    downsampling_filters = {"Box"},
                    rgb = true
   }
   local image = require 'image'
   src = src or image.lena()
   for i = 1, 10 do
      local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
   end
end
return pairwise_transform

View file

@ -0,0 +1,91 @@
require 'image'
local gm = require 'graphicsmagick'
local iproc = require 'iproc'
local data_augmentation = require 'data_augmentation'
local pairwise_transform_utils = {}
-- With probability p, halve src's resolution using a filter drawn at
-- random from `filters`; otherwise return src unchanged.
function pairwise_transform_utils.random_half(src, p, filters)
   if torch.uniform() >= p then
      return src
   end
   local chosen = filters[torch.random(1, #filters)]
   return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, chosen)
end
-- If both dimensions of src exceed max_size, return a random
-- max_size x max_size crop; otherwise return src unchanged.
-- Retries up to `tries` times to avoid flat, contentless crops.
function pairwise_transform_utils.crop_if_large(src, max_size)
   local tries = 4
   if src:size(2) > max_size and src:size(3) > max_size then
      local rect
      for i = 1, tries do
         local yi = torch.random(0, src:size(2) - max_size)
         local xi = torch.random(0, src:size(3) - max_size)
         rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
         -- ignore simple background: a zero-variance crop is a flat
         -- region, so retry. (BUG FIX: the original tested `std() >= 0`,
         -- which is always true and accepted the first crop every time.)
         if rect:float():std() > 0 then
            break
         end
      end
      return rect
   else
      return src
   end
end
-- Run the full randomized augmentation pipeline over src and return the
-- transformed image. The order of stages is significant.
function pairwise_transform_utils.preprocess(src, crop_size, options)
   local img = pairwise_transform_utils.random_half(src, options.random_half_rate,
                                                    options.downsampling_filters)
   img = pairwise_transform_utils.crop_if_large(img, math.max(crop_size * 2, options.max_size))
   img = data_augmentation.flip(img)
   img = data_augmentation.color_noise(img, options.random_color_noise_rate)
   img = data_augmentation.overlay(img, options.random_overlay_rate)
   img = data_augmentation.unsharp_mask(img, options.random_unsharp_mask_rate)
   img = data_augmentation.shift_1px(img)
   return img
end
-- Crop a matched (x, y) patch pair where y is `scale` times larger than x.
-- The y patch is size x size; the x patch is (size/scale) square at the
-- corresponding position. With probability (1 - p) the position is
-- uniform random; otherwise `tries` candidates are scored and the most
-- detail-rich one (largest SE vs. a box-blurred copy of y) is chosen.
function pairwise_transform_utils.active_cropping(x, y, size, scale, p, tries)
   -- BUG FIX: the original called assert("msg", cond) — the message string
   -- is truthy, so neither check could ever fail. Argument order fixed.
   assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
          "x:size == y:size")
   assert(size % scale == 0, "crop_size % scale == 0")
   local r = torch.uniform()
   local t = "float"
   if x:type() == "torch.ByteTensor" then
      t = "byte"
   end
   if p < r then
      -- passive path: one uniform-random position, mapped into y by scale
      local xi = torch.random(0, x:size(3) - (size + 1))
      local yi = torch.random(0, x:size(2) - (size + 1))
      local yc = iproc.crop(y, xi * scale, yi * scale, xi * scale + size, yi * scale + size)
      local xc = iproc.crop(x, xi, yi, xi + size / scale, yi + size / scale)
      return xc, yc
   else
      -- active path: score candidates against a box down/up-scaled copy
      -- of y. (An unused `test_scale` local in the original was removed.)
      local lowres = gm.Image(y, "RGB", "DHW"):
         size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
         size(y:size(3), y:size(2), "Box"):
         toTensor(t, "RGB", "DHW")
      local best_se = 0.0
      local best_xi, best_yi
      local m = torch.FloatTensor(y:size(1), size, size)
      for i = 1, tries do
         local xi = torch.random(0, x:size(3) - (size + 1)) * scale
         local yi = torch.random(0, x:size(2) - (size + 1)) * scale
         local xc = iproc.crop(y, xi, yi, xi + size, yi + size)
         local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
         local xcf = iproc.byte2float(xc)
         local lcf = iproc.byte2float(lc)
         -- squared error vs. the blurred copy; larger = more detail
         local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
         if se >= best_se then
            best_xi = xi
            best_yi = yi
            best_se = se
         end
      end
      local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
      local xc = iproc.crop(x, best_xi / scale, best_yi / scale,
                            best_xi / scale + size / scale, best_yi / scale + size / scale)
      return xc, yc
   end
end
return pairwise_transform_utils

View file

@ -49,6 +49,32 @@ local function reconstruct_rgb(model, x, offset, block_size)
end
return new_x
end
-- Tiled RGB reconstruction for models that upsample internally by `scale`.
-- Walks the INPUT image in steps of (input_block_size - offset) and writes
-- (block_size - 2*offset)-sized outputs at scaled positions into new_x.
-- Blocks that would run past the right/bottom edge are skipped — the
-- caller pads the input beforehand; presumably padding always covers the
-- last partial block (TODO confirm).
local function reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
local new_x = torch.Tensor(x:size(1), x:size(2) * scale, x:size(3) * scale):zero()
local input_block_size = block_size / scale
local output_block_size = block_size
local output_size = output_block_size - offset * 2
local output_size_in_input = input_block_size - offset
-- reusable device-side input buffer (3-channel, fixed block size)
local input = torch.CudaTensor(1, 3, input_block_size, input_block_size)
-- NOTE(review): the outer loop bounds use x (input) height while the
-- inner loop uses new_x (output) width — asymmetric; verify intent.
for i = 1, x:size(2), output_size_in_input do
for j = 1, new_x:size(3), output_size_in_input do
if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
local index = {{},
{i, i + input_block_size - 1},
{j, j + input_block_size - 1}}
input:copy(x[index])
local output = model:forward(input):view(3, output_size, output_size)
-- map the input block origin to output coordinates
local ii = (i - 1) * scale + 1
local jj = (j - 1) * scale + 1
local output_index = {{}, { ii , ii + output_size - 1 },
{ jj, jj + output_size - 1}}
new_x[output_index]:copy(output)
end
end
end
return new_x
end
local reconstruct = {}
function reconstruct.is_rgb(model)
if srcnn.channels(model) == 3 then
@ -62,6 +88,9 @@ end
function reconstruct.offset_size(model)
return srcnn.offset_size(model)
end
-- True when the model upsamples internally (srcnn.has_resize), meaning
-- callers must NOT pre-resize the input before reconstruction.
-- NOTE(review): the name reads inverted relative to has_resize — confirm.
function reconstruct.no_resize(model)
return srcnn.has_resize(model)
end
function reconstruct.image_y(model, x, offset, block_size)
block_size = block_size or 128
local output_size = block_size - offset * 2
@ -95,8 +124,14 @@ end
function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
upsampling_filter = upsampling_filter or "Box"
block_size = block_size or 128
local x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
local x_lanczos
if reconstruct.no_resize(model) then
x_lanczos = x:clone()
else
x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
end
if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage()
end
@ -162,39 +197,77 @@ function reconstruct.image_rgb(model, x, offset, block_size)
return output
end
function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
upsampling_filter = upsampling_filter or "Box"
block_size = block_size or 128
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage()
end
local output_size = block_size - offset * 2
local h_blocks = math.floor(x:size(2) / output_size) +
((x:size(2) % output_size == 0 and 0) or 1)
local w_blocks = math.floor(x:size(3) / output_size) +
((x:size(3) % output_size == 0 and 0) or 1)
local h = offset + h_blocks * output_size + offset
local w = offset + w_blocks * output_size + offset
local pad_h1 = offset
local pad_w1 = offset
local pad_h2 = (h - offset) - x:size(2)
local pad_w2 = (w - offset) - x:size(3)
x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage()
end
local y = reconstruct_rgb(model, x, offset, block_size)
local output = iproc.crop(y,
pad_w1, pad_h1,
y:size(3) - pad_w2, y:size(2) - pad_h2)
output[torch.lt(output, 0)] = 0
output[torch.gt(output, 1)] = 1
x = nil
y = nil
collectgarbage()
if reconstruct.no_resize(model) then
block_size = block_size or 128
local input_block_size = block_size / scale
local x_w = x:size(3)
local x_h = x:size(2)
local process_size = input_block_size - offset * 2
-- TODO: under construction!! bug in 4x
local h_blocks = math.floor(x_h / process_size) + 2
-- ((x_h % process_size == 0 and 0) or 1)
local w_blocks = math.floor(x_w / process_size) + 2
-- ((x_w % process_size == 0 and 0) or 1)
local h = offset + (h_blocks * process_size) + offset
local w = offset + (w_blocks * process_size) + offset
local pad_h1 = offset
local pad_w1 = offset
return output
local pad_h2 = (h - offset) - x:size(2)
local pad_w2 = (w - offset) - x:size(3)
x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage()
end
local y
y = reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
local output = iproc.crop(y,
pad_w1, pad_h1,
pad_w1 + x_w * scale, pad_h1 + x_h * scale)
output[torch.lt(output, 0)] = 0
output[torch.gt(output, 1)] = 1
x = nil
y = nil
collectgarbage()
return output
else
upsampling_filter = upsampling_filter or "Box"
block_size = block_size or 128
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage()
end
local output_size = block_size - offset * 2
local h_blocks = math.floor(x:size(2) / output_size) +
((x:size(2) % output_size == 0 and 0) or 1)
local w_blocks = math.floor(x:size(3) / output_size) +
((x:size(3) % output_size == 0 and 0) or 1)
local h = offset + h_blocks * output_size + offset
local w = offset + w_blocks * output_size + offset
local pad_h1 = offset
local pad_w1 = offset
local pad_h2 = (h - offset) - x:size(2)
local pad_w2 = (w - offset) - x:size(3)
x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage()
end
local y
y = reconstruct_rgb(model, x, offset, block_size)
local output = iproc.crop(y,
pad_w1, pad_h1,
y:size(3) - pad_w2, y:size(2) - pad_h2)
output[torch.lt(output, 0)] = 0
output[torch.gt(output, 1)] = 1
x = nil
y = nil
collectgarbage()
return output
end
end
function reconstruct.image(model, x, block_size)

View file

@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
cmd:option("-test", "images/miku_small.png", 'path to test image')
cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", 'method to training (noise|scale)')
cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)')
cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)')
cmd:option("-noise_level", 1, '(1|2|3)')
cmd:option("-style", "art", '(art|photo)')
cmd:option("-color", 'rgb', '(y|rgb)')
@ -34,7 +34,7 @@ cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution im
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.0005, 'learning rate for adam')
cmd:option("-crop_size", 46, 'crop size')
cmd:option("-crop_size", 48, 'crop size')
cmd:option("-max_size", 256, 'if image is larger than N, image will be crop randomly')
cmd:option("-batch_size", 8, 'mini batch size')
cmd:option("-patches", 16, 'number of patch samples')

View file

@ -9,14 +9,23 @@ function nn.SpatialConvolutionMM:reset(stdv)
self.weight:normal(0, stdv)
self.bias:zero()
end
function nn.SpatialFullConvolution:reset(stdv)
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
self.weight:normal(0, stdv)
self.bias:zero()
end
if cudnn and cudnn.SpatialConvolution then
function cudnn.SpatialConvolution:reset(stdv)
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
self.weight:normal(0, stdv)
self.bias:zero()
end
function cudnn.SpatialFullConvolution:reset(stdv)
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
self.weight:normal(0, stdv)
self.bias:zero()
end
end
function nn.SpatialConvolutionMM:clearState()
if self.gradWeight then
self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
@ -26,9 +35,12 @@ function nn.SpatialConvolutionMM:clearState()
end
return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
function srcnn.channels(model)
return model:get(model:size() - 1).weight:size(1)
if model.w2nn_channels ~= nil then
return model.w2nn_channels
else
return model:get(model:size() - 1).weight:size(1)
end
end
function srcnn.backend(model)
local conv = model:findModules("cudnn.SpatialConvolution")
@ -47,32 +59,54 @@ function srcnn.color(model)
end
end
function srcnn.name(model)
local backend_cudnn = false
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv == 0 then
backend_cudnn = true
conv = model:findModules("cudnn.SpatialConvolution")
end
if #conv == 7 then
return "vgg_7"
elseif #conv == 12 then
return "vgg_12"
if model.w2nn_arch_name then
return model.w2nn_arch_name
else
return nil
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv == 0 then
conv = model:findModules("cudnn.SpatialConvolution")
end
if #conv == 7 then
return "vgg_7"
elseif #conv == 12 then
return "vgg_12"
else
error("unsupported model name")
end
end
end
function srcnn.offset_size(model)
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv == 0 then
conv = model:findModules("cudnn.SpatialConvolution")
if model.w2nn_offset ~= nil then
return model.w2nn_offset
else
local name = srcnn.name(model)
if name:match("vgg_") then
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv == 0 then
conv = model:findModules("cudnn.SpatialConvolution")
end
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
else
error("unsupported model name")
end
end
end
function srcnn.has_resize(model)
if model.w2nn_resize ~= nil then
return model.w2nn_resize
else
local name = srcnn.name(model)
if name:match("upconv") ~= nil then
return true
else
return false
end
end
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
end
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then
return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
@ -82,6 +116,15 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
error("unsupported backend:" .. backend)
end
end
-- Backend-dispatching constructor for a transposed ("full") convolution.
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   if backend == "cudnn" then
      return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   elseif backend == "cunn" then
      return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
   else
      error("unsupported backend:" .. backend)
   end
end
-- VGG style net(7 layers)
function srcnn.vgg_7(backend, ch)
@ -100,6 +143,11 @@ function srcnn.vgg_7(backend, ch)
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "vgg_7"
model.w2nn_offset = 7
model.w2nn_resize = false
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
@ -132,12 +180,103 @@ function srcnn.vgg_12(backend, ch)
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "vgg_12"
model.w2nn_offset = 12
model.w2nn_resize = false
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- Dilated Convolution (7 layers)
-- Dilated Convolution (7 layers): same-resolution model whose middle
-- layers use dilation 2/2/4, giving a 12px receptive-field offset.
-- NOTE(review): the dilated layers use nn.SpatialDilatedConvolution
-- directly, ignoring the `backend` argument (no cudnn dispatch) —
-- presumably intentional; confirm.
function srcnn.dilated_7(backend, ch)
local model = nn.Sequential()
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
-- metadata consumed by srcnn.name/offset_size/has_resize/channels
model.w2nn_arch_name = "dilated_7"
model.w2nn_offset = 12
model.w2nn_resize = false
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- Up Convolution
-- Up Convolution
-- Six 3x3 conv layers followed by a 4x4/stride-2 transposed convolution:
-- the model outputs at 2x the input resolution, so callers must not
-- pre-upscale the input (w2nn_resize = true).
function srcnn.upconv_7(backend, ch)
local model = nn.Sequential()
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
-- 2x upsampling output layer
model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1))
-- metadata consumed by srcnn.name/offset_size/has_resize/channels
model.w2nn_arch_name = "upconv_7"
model.w2nn_offset = 12
model.w2nn_resize = true
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- 4x upscaling model: transposed conv (2x) -> six 3x3 conv layers ->
-- transposed conv (2x). Marked "under construction" upstream.
function srcnn.upconv_8_4x(backend, ch)
   local model = nn.Sequential()
   model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   -- BUG FIX: the output plane count was hard-coded to 3, contradicting
   -- w2nn_channels = ch below and breaking ch == 1 (Y-only) models.
   model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 1, 1))
   -- metadata consumed by srcnn.name/offset_size/has_resize/channels
   model.w2nn_arch_name = "upconv_8_4x"
   model.w2nn_offset = 12
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
function srcnn.create(model_name, backend, color)
model_name = model_name or "vgg_7"
backend = backend or "cunn"
@ -150,12 +289,14 @@ function srcnn.create(model_name, backend, color)
else
error("unsupported color: " .. color)
end
if model_name == "vgg_7" then
return srcnn.vgg_7(backend, ch)
elseif model_name == "vgg_12" then
return srcnn.vgg_12(backend, ch)
if srcnn[model_name] then
return srcnn[model_name](backend, ch)
else
error("unsupported model_name: " .. model_name)
end
end
--local model = srcnn.upconv_8_4x("cunn", 3):cuda()
--print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
return srcnn

View file

@ -15,7 +15,9 @@ local pairwise_transform = require 'pairwise_transform'
local image_loader = require 'image_loader'
local function save_test_scale(model, rgb, file)
local up = reconstruct.scale(model, settings.scale, rgb, 128, settings.upsampling_filter)
local up = reconstruct.scale(model, settings.scale, rgb,
settings.scale * settings.crop_size,
settings.upsampling_filter)
image.save(file, up)
end
local function save_test_jpeg(model, rgb, file)
@ -96,6 +98,7 @@ local function create_criterion(model)
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
@ -108,7 +111,7 @@ local function create_criterion(model)
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
end
end
local function transformer(x, is_validation, n, offset)
local function transformer(model, x, is_validation, n, offset)
x = compression.decompress(x)
n = n or settings.patches
@ -145,7 +148,8 @@ local function transformer(x, is_validation, n, offset)
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb"),
gamma_correction = settings.gamma_correction
gamma_correction = settings.gamma_correction,
x_upsampling = not srcnn.has_resize(model)
})
elseif settings.method == "noise" then
return pairwise_transform.jpeg(x,
@ -183,6 +187,22 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
end
end
end
-- Drop compressed images whose dimensions are too small to yield
-- crop_size patches once downscaled by settings.scale.
-- Returns a new table; the input table is left untouched.
local function remove_small_image(x)
   local min_side = settings.crop_size + 16
   local kept = {}
   for i, img in ipairs(x) do
      local s = compression.size(img)
      -- s[2] = height, s[3] = width (assumed CxHxW layout -- matches usage here)
      if s[2] / settings.scale > min_side and s[3] / settings.scale > min_side then
         kept[#kept + 1] = img
      end
      -- periodic collection to keep memory bounded over large datasets
      if i % 100 == 0 then
         collectgarbage()
      end
   end
   print(string.format("removed %d small images", #x - #kept))
   return kept
end
local function plot(train, valid)
gnuplot.plot({
{'training', torch.Tensor(train), '-'},
@ -194,11 +214,11 @@ local function train()
local model = srcnn.create(settings.model, settings.backend, settings.color)
local offset = reconstruct.offset_size(model)
local pairwise_func = function(x, is_validation, n)
return transformer(x, is_validation, n, offset)
return transformer(model, x, is_validation, n, offset)
end
local criterion = create_criterion(model)
local eval_metric = nn.MSECriterion():cuda()
local x = torch.load(settings.images)
local x = remove_small_image(torch.load(settings.images))
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
local adam_config = {
learningRate = settings.learning_rate,
@ -222,10 +242,16 @@ local function train()
model:cuda()
print("load .. " .. #train_x)
local x = torch.Tensor(settings.patches * #train_x,
ch, settings.crop_size, settings.crop_size)
local x = nil
local y = torch.Tensor(settings.patches * #train_x,
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
if srcnn.has_resize(model) then
x = torch.Tensor(settings.patches * #train_x,
ch, settings.crop_size / settings.scale, settings.crop_size / settings.scale)
else
x = torch.Tensor(settings.patches * #train_x,
ch, settings.crop_size, settings.crop_size)
end
for epoch = 1, settings.epoch do
model:training()
print("# " .. epoch)