1
0
Fork 0
mirror of synced 2024-06-26 10:10:49 +12:00

Add new models

upconv_7 is 2.3x faster than previous model
This commit is contained in:
nagadomi 2016-05-13 09:49:53 +09:00
parent e62305377f
commit 51ae485cd1
8 changed files with 608 additions and 320 deletions

View file

@@ -1,255 +1,9 @@
require 'image' require 'pl'
local gm = require 'graphicsmagick'
local iproc = require 'iproc'
local data_augmentation = require 'data_augmentation'
local pairwise_transform = {} local pairwise_transform = {}
local function random_half(src, p, filters) pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_scale'))
if torch.uniform() < p then pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_jpeg'))
local filter = filters[torch.random(1, #filters)]
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
else
return src
end
end
local function crop_if_large(src, max_size)
local tries = 4
if src:size(2) > max_size and src:size(3) > max_size then
local rect
for i = 1, tries do
local yi = torch.random(0, src:size(2) - max_size)
local xi = torch.random(0, src:size(3) - max_size)
rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
-- ignore simple background
if rect:float():std() >= 0 then
break
end
end
return rect
else
return src
end
end
local function preprocess(src, crop_size, options)
local dest = src
dest = random_half(dest, options.random_half_rate, options.downsampling_filters)
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
dest = data_augmentation.flip(dest)
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
dest = data_augmentation.shift_1px(dest)
return dest print(pairwise_transform)
end
local function active_cropping(x, y, size, p, tries)
assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
local r = torch.uniform()
local t = "float"
if x:type() == "torch.ByteTensor" then
t = "byte"
end
if p < r then
local xi = torch.random(0, y:size(3) - (size + 1))
local yi = torch.random(0, y:size(2) - (size + 1))
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
return xc, yc
else
local lowres = gm.Image(x, "RGB", "DHW"):
size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
size(x:size(3), x:size(2), "Box"):
toTensor(t, "RGB", "DHW")
local best_se = 0.0
local best_xc, best_yc
local m = torch.FloatTensor(x:size(1), size, size)
for i = 1, tries do
local xi = torch.random(0, y:size(3) - (size + 1))
local yi = torch.random(0, y:size(2) - (size + 1))
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
local xcf = iproc.byte2float(xc)
local lcf = iproc.byte2float(lc)
local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
if se >= best_se then
best_xc = xcf
best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size))
best_se = se
end
end
return best_xc, best_yc
end
end
function pairwise_transform.scale(src, scale, size, offset, n, options)
local filters = options.downsampling_filters
local unstable_region_offset = 8
local downsampling_filter = filters[torch.random(1, #filters)]
local y = preprocess(src, size, options)
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
local down_scale = 1.0 / scale
local x
if options.gamma_correction then
x = iproc.scale(iproc.scale_with_gamma22(y, y:size(3) * down_scale,
y:size(2) * down_scale, downsampling_filter),
y:size(3), y:size(2), options.upsampling_filter)
else
x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
y:size(2) * down_scale, downsampling_filter),
y:size(3), y:size(2), options.upsampling_filter)
end
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y,
size,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
return batch
end
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
local unstable_region_offset = 8
local y = preprocess(src, size, options)
local x = y
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg"):depth(8)
if torch.uniform() < options.jpeg_chroma_subsampling_rate then
-- YUV 420
x:samplingFactors({2.0, 1.0, 1.0})
else
-- YUV 444
x:samplingFactors({1.0, 1.0, 1.0})
end
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y, size,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
if torch.uniform() < options.nr_rate then
-- reducing noise
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
else
-- ratain useful details
table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
end
return batch
end
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
if style == "art" then
if level == 1 then
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
size, offset, n, options)
elseif level == 2 or level == 3 then
-- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
local r = torch.uniform()
if r > 0.6 then
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
size, offset, n, options)
elseif r > 0.3 then
local quality1 = torch.random(37, 70)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_(src, {quality1, quality2},
size, offset, n, options)
else
local quality1 = torch.random(52, 70)
local quality2 = quality1 - torch.random(5, 15)
local quality3 = quality1 - torch.random(15, 25)
return pairwise_transform.jpeg_(src,
{quality1, quality2, quality3},
size, offset, n, options)
end
else
error("unknown noise level: " .. level)
end
elseif style == "photo" then
-- level adjusting by -nr_rate
return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
size, offset, n,
options)
else
error("unknown style: " .. style)
end
end
function pairwise_transform.test_jpeg(src)
torch.setdefaulttensortype("torch.FloatTensor")
local options = {random_color_noise_rate = 0.5,
random_half_rate = 0.5,
random_overlay_rate = 0.5,
random_unsharp_mask_rate = 0.5,
jpeg_chroma_subsampling_rate = 0.5,
nr_rate = 1.0,
active_cropping_rate = 0.5,
active_cropping_tries = 10,
max_size = 256,
rgb = true
}
local image = require 'image'
local src = image.lena()
for i = 1, 9 do
local xy = pairwise_transform.jpeg(src,
"art",
torch.random(1, 2),
128, 7, 1, options)
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
end
end
function pairwise_transform.test_scale(src)
torch.setdefaulttensortype("torch.FloatTensor")
local options = {random_color_noise_rate = 0.5,
random_half_rate = 0.5,
random_overlay_rate = 0.5,
random_unsharp_mask_rate = 0.5,
active_cropping_rate = 0.5,
active_cropping_tries = 10,
max_size = 256,
rgb = true
}
local image = require 'image'
local src = image.lena()
for i = 1, 10 do
local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
end
end
return pairwise_transform return pairwise_transform

View file

@@ -0,0 +1,117 @@
local pairwise_utils = require 'pairwise_transform_utils'
local gm = require 'graphicsmagick'
local iproc = require 'iproc'
local pairwise_transform = {}
-- Build n (input, target) training patch pairs for JPEG noise reduction.
-- src: source image tensor; quality: list of JPEG qualities applied as
-- successive re-compressions; size: crop size; offset: target border trim;
-- n: number of pairs; options: augmentation/cropping settings.
-- Returns a list of {x_patch, y_patch} pairs.
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
local unstable_region_offset = 8
-- y is the clean (augmented) target; x starts as y and is degraded below
local y = pairwise_utils.preprocess(src, size, options)
local x = y
-- degrade x by compressing it once per entry in `quality` (multi-pass JPEG)
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg"):depth(8)
if torch.uniform() < options.jpeg_chroma_subsampling_rate then
-- YUV 420
x:samplingFactors({2.0, 1.0, 1.0})
else
-- YUV 444
x:samplingFactors({1.0, 1.0, 1.0})
end
-- round-trip through an in-memory JPEG blob to apply the compression
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
-- trim borders where JPEG blocking artifacts near the edge are unreliable
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
local batch = {}
for i = 1, n do
-- scale factor 1: x and y have identical geometry for the noise task
local xc, yc = pairwise_utils.active_cropping(x, y, size, 1,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
-- Y-channel-only training; NOTE(review): `image` is not required in this
-- file, so this relies on a global `image` package — confirm callers load it
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
if torch.uniform() < options.nr_rate then
-- reducing noise
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
else
-- retain useful details (pair the clean patch with itself)
table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
end
return batch
end
-- Dispatch to jpeg_() with a quality schedule chosen by style and noise level.
-- "art" level 1 uses one mild compression; levels 2/3 randomly pick one of
-- three progressively harsher multi-pass schedules; "photo" uses a single
-- wide-range compression. Raises on unknown style or level.
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
   if style == "art" then
      if level == 1 then
         return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
                                         size, offset, n, options)
      end
      if level ~= 2 and level ~= 3 then
         error("unknown noise level: " .. level)
      end
      -- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
      local roll = torch.uniform()
      if roll > 0.6 then
         -- single harsh compression
         return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
                                         size, offset, n, options)
      elseif roll > 0.3 then
         -- two-pass compression, second pass slightly lower quality
         local q1 = torch.random(37, 70)
         local q2 = q1 - torch.random(5, 10)
         return pairwise_transform.jpeg_(src, {q1, q2},
                                         size, offset, n, options)
      else
         -- three-pass compression with increasingly lower quality
         local q1 = torch.random(52, 70)
         local q2 = q1 - torch.random(5, 15)
         local q3 = q1 - torch.random(15, 25)
         return pairwise_transform.jpeg_(src, {q1, q2, q3},
                                         size, offset, n, options)
      end
   elseif style == "photo" then
      -- level adjusting by -nr_rate
      return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
                                      size, offset, n, options)
   else
      error("unknown style: " .. style)
   end
end
-- Visual smoke test: generate and display (input, target) pairs for
-- art-style noise levels 1-2. Requires a display-capable torch environment.
function pairwise_transform.test_jpeg(src)
   torch.setdefaulttensortype("torch.FloatTensor")
   local options = {
      random_half_rate = 0.5,
      random_color_noise_rate = 0.5,
      random_overlay_rate = 0.5,
      random_unsharp_mask_rate = 0.5,
      jpeg_chroma_subsampling_rate = 0.5,
      nr_rate = 1.0,
      active_cropping_rate = 0.5,
      active_cropping_tries = 10,
      max_size = 256,
      rgb = true
   }
   local image = require 'image'
   local src = image.lena()
   for trial = 1, 9 do
      local xy = pairwise_transform.jpeg(src, "art", torch.random(1, 2),
                                         128, 7, 1, options)
      image.display({image = xy[1][1], legend = "y:" .. (trial * 10), min = 0, max = 1})
      image.display({image = xy[1][2], legend = "x:" .. (trial * 10), min = 0, max = 1})
   end
end
return pairwise_transform

View file

@@ -0,0 +1,86 @@
local pairwise_utils = require 'pairwise_transform_utils'
local iproc = require 'iproc'
local pairwise_transform = {}
-- Build n (input, target) training patch pairs for super-resolution.
-- y is the augmented high-resolution target; x is y downscaled by 1/scale
-- (and optionally re-upscaled when options.x_upsampling is set, for models
-- that expect pre-enlarged input). Returns a list of {x_patch, y_patch}.
function pairwise_transform.scale(src, scale, size, offset, n, options)
local filters = options.downsampling_filters
local unstable_region_offset = 8
-- pick a random downsampling filter per sample for filter-robust training;
-- NOTE(review): assumes filters is a list (table), not a single string
local downsampling_filter = filters[torch.random(1, #filters)]
local y = preprocess(src, size, options)
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
local down_scale = 1.0 / scale
local x
if options.gamma_correction then
-- downscale in linear light (gamma 2.2 aware)
local small = iproc.scale_with_gamma22(y, y:size(3) * down_scale,
y:size(2) * down_scale, downsampling_filter)
if options.x_upsampling then
x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter)
else
x = small
end
else
local small = iproc.scale(y, y:size(3) * down_scale,
y:size(2) * down_scale, downsampling_filter)
if options.x_upsampling then
x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter)
else
x = small
end
end
if options.x_upsampling then
-- x and y are same-sized: trim the interpolation-unstable border from both
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
else
-- x stays small: it must be exactly 1/scale of y in both dimensions
assert(x:size(1) == y:size(1) and x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
end
-- active_cropping needs the x->y size ratio: 1 when x was re-upscaled,
-- otherwise the model's scale factor
local scale_inner = scale
if options.x_upsampling then
scale_inner = 1
end
local batch = {}
for i = 1, n do
local xc, yc = pairwise_utils.active_cropping(x, y,
size,
scale_inner,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
-- Y-channel-only training; NOTE(review): relies on a global `image`
-- package (not required in this file) — confirm callers load it
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
return batch
end
-- Visual smoke test: generate and display (input, target) pairs for 2x
-- scale training with x_upsampling disabled (small-input models).
function pairwise_transform.test_scale(src)
   torch.setdefaulttensortype("torch.FloatTensor")
   local options = {
      random_color_noise_rate = 0.5,
      random_half_rate = 0.5,
      random_overlay_rate = 0.5,
      random_unsharp_mask_rate = 0.5,
      active_cropping_rate = 0.5,
      active_cropping_tries = 10,
      max_size = 256,
      x_upsampling = false,
      -- BUGFIX: must be a list. random_half()/scale() select a filter with
      -- filters[torch.random(1, #filters)]; a bare string ("Box") would be
      -- indexed to nil and a nil filter passed to iproc.scale.
      downsampling_filters = {"Box"},
      rgb = true
   }
   local image = require 'image'
   local src = image.lena()
   for i = 1, 10 do
      local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
      image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
      image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
   end
end
return pairwise_transform

View file

@@ -0,0 +1,91 @@
require 'image'
local gm = require 'graphicsmagick'
local iproc = require 'iproc'
local data_augmentation = require 'data_augmentation'
local pairwise_transform_utils = {}
-- With probability p, return src downscaled to half size using a filter
-- picked uniformly at random from `filters`; otherwise return src unchanged.
function pairwise_transform_utils.random_half(src, p, filters)
   if torch.uniform() >= p then
      return src
   end
   local chosen = filters[torch.random(1, #filters)]
   return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, chosen)
end
-- If src exceeds max_size in both dimensions, return a random
-- max_size x max_size crop; otherwise return src unchanged.
-- Retries up to `tries` times to avoid sampling a flat region
-- (simple background); the last candidate is used if all are flat.
function pairwise_transform_utils.crop_if_large(src, max_size)
   local tries = 4
   if src:size(2) > max_size and src:size(3) > max_size then
      local rect
      for i = 1, tries do
         local yi = torch.random(0, src:size(2) - max_size)
         local xi = torch.random(0, src:size(3) - max_size)
         rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
         -- ignore simple background: accept only crops with some variance.
         -- BUGFIX: the original tested `std() >= 0`, which is true for every
         -- crop (std is non-negative), so the retry loop never retried.
         if rect:float():std() > 0 then
            break
         end
      end
      return rect
   else
      return src
   end
end
-- Produce one augmented sample from src: optional half-downscale, a
-- size-limiting random crop, then the standard augmentation chain
-- (flip, color noise, overlay, unsharp mask, 1px shift).
function pairwise_transform_utils.preprocess(src, crop_size, options)
   local out = pairwise_transform_utils.random_half(src,
      options.random_half_rate, options.downsampling_filters)
   out = pairwise_transform_utils.crop_if_large(out,
      math.max(crop_size * 2, options.max_size))
   out = data_augmentation.flip(out)
   out = data_augmentation.color_noise(out, options.random_color_noise_rate)
   out = data_augmentation.overlay(out, options.random_overlay_rate)
   out = data_augmentation.unsharp_mask(out, options.random_unsharp_mask_rate)
   return data_augmentation.shift_1px(out)
end
-- Crop a matched (x, y) training patch pair, where y is `scale` times
-- larger than x and `size` is the crop size in y coordinates.
-- With probability (1 - p) a uniform random crop is used; otherwise up to
-- `tries` candidates are scored and the one differing most from a
-- box-blurred copy of y (i.e. the most detailed region) is kept.
function pairwise_transform_utils.active_cropping(x, y, size, scale, p, tries)
   -- BUGFIX: the original called assert("msg", cond) — a string condition is
   -- always truthy, so the checks could never fire. Argument order fixed.
   assert(x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3),
          "x:size == y:size")
   assert(size % scale == 0, "crop_size % scale == 0")
   local r = torch.uniform()
   local t = "float"
   if x:type() == "torch.ByteTensor" then
      t = "byte"
   end
   if p < r then
      -- plain random crop (coordinates sampled in x space, mapped to y space)
      local xi = torch.random(0, x:size(3) - (size + 1))
      local yi = torch.random(0, x:size(2) - (size + 1))
      local yc = iproc.crop(y, xi * scale, yi * scale, xi * scale + size, yi * scale + size)
      local xc = iproc.crop(x, xi, yi, xi + size / scale, yi + size / scale)
      return xc, yc
   else
      -- active crop: build a blurred proxy of y (downscale then upscale with
      -- Box) and keep the candidate with the largest squared error against it
      local lowres = gm.Image(y, "RGB", "DHW"):
         size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
         size(y:size(3), y:size(2), "Box"):
         toTensor(t, "RGB", "DHW")
      local best_se = 0.0
      local best_xi, best_yi
      local m = torch.FloatTensor(y:size(1), size, size)
      for i = 1, tries do
         -- sample in x space, then convert to y coordinates
         local xi = torch.random(0, x:size(3) - (size + 1)) * scale
         local yi = torch.random(0, x:size(2) - (size + 1)) * scale
         local xc = iproc.crop(y, xi, yi, xi + size, yi + size)
         local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
         local xcf = iproc.byte2float(xc)
         local lcf = iproc.byte2float(lc)
         local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
         if se >= best_se then
            best_xi = xi
            best_yi = yi
            best_se = se
         end
      end
      local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
      local xc = iproc.crop(x, best_xi / scale, best_yi / scale,
                            best_xi / scale + size / scale, best_yi / scale + size / scale)
      return xc, yc
   end
end
-- module export
return pairwise_transform_utils

View file

@ -49,6 +49,32 @@ local function reconstruct_rgb(model, x, offset, block_size)
end end
return new_x return new_x
end end
local function reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
local new_x = torch.Tensor(x:size(1), x:size(2) * scale, x:size(3) * scale):zero()
local input_block_size = block_size / scale
local output_block_size = block_size
local output_size = output_block_size - offset * 2
local output_size_in_input = input_block_size - offset
local input = torch.CudaTensor(1, 3, input_block_size, input_block_size)
for i = 1, x:size(2), output_size_in_input do
for j = 1, new_x:size(3), output_size_in_input do
if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
local index = {{},
{i, i + input_block_size - 1},
{j, j + input_block_size - 1}}
input:copy(x[index])
local output = model:forward(input):view(3, output_size, output_size)
local ii = (i - 1) * scale + 1
local jj = (j - 1) * scale + 1
local output_index = {{}, { ii , ii + output_size - 1 },
{ jj, jj + output_size - 1}}
new_x[output_index]:copy(output)
end
end
end
return new_x
end
local reconstruct = {} local reconstruct = {}
function reconstruct.is_rgb(model) function reconstruct.is_rgb(model)
if srcnn.channels(model) == 3 then if srcnn.channels(model) == 3 then
@ -62,6 +88,9 @@ end
function reconstruct.offset_size(model) function reconstruct.offset_size(model)
return srcnn.offset_size(model) return srcnn.offset_size(model)
end end
function reconstruct.no_resize(model)
return srcnn.has_resize(model)
end
function reconstruct.image_y(model, x, offset, block_size) function reconstruct.image_y(model, x, offset, block_size)
block_size = block_size or 128 block_size = block_size or 128
local output_size = block_size - offset * 2 local output_size = block_size - offset * 2
@ -95,8 +124,14 @@ end
function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter) function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
upsampling_filter = upsampling_filter or "Box" upsampling_filter = upsampling_filter or "Box"
block_size = block_size or 128 block_size = block_size or 128
local x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
local x_lanczos
if reconstruct.no_resize(model) then
x_lanczos = x:clone()
else
x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter) x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
end
if x:size(2) * x:size(3) > 2048*2048 then if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage() collectgarbage()
end end
@ -162,6 +197,42 @@ function reconstruct.image_rgb(model, x, offset, block_size)
return output return output
end end
function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter) function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
if reconstruct.no_resize(model) then
block_size = block_size or 128
local input_block_size = block_size / scale
local x_w = x:size(3)
local x_h = x:size(2)
local process_size = input_block_size - offset * 2
-- TODO: under construction!! bug in 4x
local h_blocks = math.floor(x_h / process_size) + 2
-- ((x_h % process_size == 0 and 0) or 1)
local w_blocks = math.floor(x_w / process_size) + 2
-- ((x_w % process_size == 0 and 0) or 1)
local h = offset + (h_blocks * process_size) + offset
local w = offset + (w_blocks * process_size) + offset
local pad_h1 = offset
local pad_w1 = offset
local pad_h2 = (h - offset) - x:size(2)
local pad_w2 = (w - offset) - x:size(3)
x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage()
end
local y
y = reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
local output = iproc.crop(y,
pad_w1, pad_h1,
pad_w1 + x_w * scale, pad_h1 + x_h * scale)
output[torch.lt(output, 0)] = 0
output[torch.gt(output, 1)] = 1
x = nil
y = nil
collectgarbage()
return output
else
upsampling_filter = upsampling_filter or "Box" upsampling_filter = upsampling_filter or "Box"
block_size = block_size or 128 block_size = block_size or 128
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter) x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
@ -184,7 +255,8 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_f
if x:size(2) * x:size(3) > 2048*2048 then if x:size(2) * x:size(3) > 2048*2048 then
collectgarbage() collectgarbage()
end end
local y = reconstruct_rgb(model, x, offset, block_size) local y
y = reconstruct_rgb(model, x, offset, block_size)
local output = iproc.crop(y, local output = iproc.crop(y,
pad_w1, pad_h1, pad_w1, pad_h1,
y:size(3) - pad_w2, y:size(2) - pad_h2) y:size(3) - pad_w2, y:size(2) - pad_h2)
@ -196,6 +268,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_f
return output return output
end end
end
function reconstruct.image(model, x, block_size) function reconstruct.image(model, x, block_size)
local i2rgb = false local i2rgb = false

View file

@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
cmd:option("-test", "images/miku_small.png", 'path to test image') cmd:option("-test", "images/miku_small.png", 'path to test image')
cmd:option("-model_dir", "./models", 'model directory') cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", 'method to training (noise|scale)') cmd:option("-method", "scale", 'method to training (noise|scale)')
cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)') cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)')
cmd:option("-noise_level", 1, '(1|2|3)') cmd:option("-noise_level", 1, '(1|2|3)')
cmd:option("-style", "art", '(art|photo)') cmd:option("-style", "art", '(art|photo)')
cmd:option("-color", 'rgb', '(y|rgb)') cmd:option("-color", 'rgb', '(y|rgb)')
@ -34,7 +34,7 @@ cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution im
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)') cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
cmd:option("-scale", 2.0, 'scale factor (2)') cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.0005, 'learning rate for adam') cmd:option("-learning_rate", 0.0005, 'learning rate for adam')
cmd:option("-crop_size", 46, 'crop size') cmd:option("-crop_size", 48, 'crop size')
cmd:option("-max_size", 256, 'if image is larger than N, image will be crop randomly') cmd:option("-max_size", 256, 'if image is larger than N, image will be crop randomly')
cmd:option("-batch_size", 8, 'mini batch size') cmd:option("-batch_size", 8, 'mini batch size')
cmd:option("-patches", 16, 'number of patch samples') cmd:option("-patches", 16, 'number of patch samples')

View file

@ -9,14 +9,23 @@ function nn.SpatialConvolutionMM:reset(stdv)
self.weight:normal(0, stdv) self.weight:normal(0, stdv)
self.bias:zero() self.bias:zero()
end end
function nn.SpatialFullConvolution:reset(stdv)
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
self.weight:normal(0, stdv)
self.bias:zero()
end
if cudnn and cudnn.SpatialConvolution then if cudnn and cudnn.SpatialConvolution then
function cudnn.SpatialConvolution:reset(stdv) function cudnn.SpatialConvolution:reset(stdv)
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane)) stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
self.weight:normal(0, stdv) self.weight:normal(0, stdv)
self.bias:zero() self.bias:zero()
end end
function cudnn.SpatialFullConvolution:reset(stdv)
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
self.weight:normal(0, stdv)
self.bias:zero()
end
end end
function nn.SpatialConvolutionMM:clearState() function nn.SpatialConvolutionMM:clearState()
if self.gradWeight then if self.gradWeight then
self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero() self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
@ -26,10 +35,13 @@ function nn.SpatialConvolutionMM:clearState()
end end
return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput') return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end end
function srcnn.channels(model) function srcnn.channels(model)
if model.w2nn_channels ~= nil then
return model.w2nn_channels
else
return model:get(model:size() - 1).weight:size(1) return model:get(model:size() - 1).weight:size(1)
end end
end
function srcnn.backend(model) function srcnn.backend(model)
local conv = model:findModules("cudnn.SpatialConvolution") local conv = model:findModules("cudnn.SpatialConvolution")
if #conv > 0 then if #conv > 0 then
@ -47,10 +59,11 @@ function srcnn.color(model)
end end
end end
function srcnn.name(model) function srcnn.name(model)
local backend_cudnn = false if model.w2nn_arch_name then
return model.w2nn_arch_name
else
local conv = model:findModules("nn.SpatialConvolutionMM") local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv == 0 then if #conv == 0 then
backend_cudnn = true
conv = model:findModules("cudnn.SpatialConvolution") conv = model:findModules("cudnn.SpatialConvolution")
end end
if #conv == 7 then if #conv == 7 then
@ -58,10 +71,16 @@ function srcnn.name(model)
elseif #conv == 12 then elseif #conv == 12 then
return "vgg_12" return "vgg_12"
else else
return nil error("unsupported model name")
end
end end
end end
function srcnn.offset_size(model) function srcnn.offset_size(model)
if model.w2nn_offset ~= nil then
return model.w2nn_offset
else
local name = srcnn.name(model)
if name:match("vgg_") then
local conv = model:findModules("nn.SpatialConvolutionMM") local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv == 0 then if #conv == 0 then
conv = model:findModules("cudnn.SpatialConvolution") conv = model:findModules("cudnn.SpatialConvolution")
@ -71,8 +90,23 @@ function srcnn.offset_size(model)
offset = offset + (conv[i].kW - 1) / 2 offset = offset + (conv[i].kW - 1) / 2
end end
return math.floor(offset) return math.floor(offset)
else
error("unsupported model name")
end
end
end
function srcnn.has_resize(model)
if model.w2nn_resize ~= nil then
return model.w2nn_resize
else
local name = srcnn.name(model)
if name:match("upconv") ~= nil then
return true
else
return false
end
end
end end
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then if backend == "cunn" then
return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
@ -82,6 +116,15 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
error("unsupported backend:" .. backend) error("unsupported backend:" .. backend)
end end
end end
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then
return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
elseif backend == "cudnn" then
return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
else
error("unsupported backend:" .. backend)
end
end
-- VGG style net(7 layers) -- VGG style net(7 layers)
function srcnn.vgg_7(backend, ch) function srcnn.vgg_7(backend, ch)
@ -100,6 +143,11 @@ function srcnn.vgg_7(backend, ch)
model:add(w2nn.LeakyReLU(0.1)) model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3)) model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "vgg_7"
model.w2nn_offset = 7
model.w2nn_resize = false
model.w2nn_channels = ch
--model:cuda() --model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
@ -132,12 +180,103 @@ function srcnn.vgg_12(backend, ch)
model:add(w2nn.LeakyReLU(0.1)) model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0)) model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3)) model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "vgg_12"
model.w2nn_offset = 12
model.w2nn_resize = false
model.w2nn_channels = ch
--model:cuda() --model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model return model
end end
-- Dilated Convolution (7 layers)
function srcnn.dilated_7(backend, ch)
local model = nn.Sequential()
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
model.w2nn_arch_name = "dilated_7"
model.w2nn_offset = 12
model.w2nn_resize = false
model.w2nn_channels = ch
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
-- Up Convolution
-- 2x upscaling model: six unpadded 3x3 convolutions over the low-resolution
-- input, then a single 4x4 stride-2 up-convolution that doubles the spatial
-- resolution. Working at LR resolution is what makes this faster than the
-- vgg_* models, which process the pre-upsampled image.
function srcnn.upconv_7(backend, ch)
   local net = nn.Sequential()
   -- feature extraction: channel widths for the conv stack
   local widths = {ch, 32, 32, 64, 64, 128, 128}
   for i = 1, #widths - 1 do
      net:add(SpatialConvolution(backend, widths[i], widths[i + 1], 3, 3, 1, 1, 0, 0))
      net:add(w2nn.LeakyReLU(0.1))
   end
   -- final up-convolution: kernel 4, stride 2, pad 1 -> 2x output size
   net:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1))
   net.w2nn_arch_name = "upconv_7"
   net.w2nn_offset = 12
   -- resize flag tells the pipeline the model itself performs the upscale
   net.w2nn_resize = true
   net.w2nn_channels = ch
   --net:cuda()
   --print(net:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return net
end
-- 4x upscaling model (8 layers): an up-convolution at each end doubles the
-- resolution twice, with six unpadded 3x3 convolutions in between.
function srcnn.upconv_8_4x(backend, ch)
   local model = nn.Sequential()
   -- first up-convolution: kernel 4, stride 2, pad 1 -> 2x resolution
   model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
   model:add(w2nn.LeakyReLU(0.1))
   -- BUG FIX: the output layer previously hard-coded 3 output channels,
   -- which contradicted `model.w2nn_channels = ch` and broke single-channel
   -- ("y" color) models; use `ch` like every other architecture here.
   -- Backward-compatible: identical network when ch == 3 (RGB).
   model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 1, 1))
   model.w2nn_arch_name = "upconv_8_4x"
   model.w2nn_offset = 12
   -- resize flag tells the pipeline the model itself performs the upscale
   model.w2nn_resize = true
   model.w2nn_channels = ch
   --model:cuda()
   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
   return model
end
function srcnn.create(model_name, backend, color) function srcnn.create(model_name, backend, color)
model_name = model_name or "vgg_7" model_name = model_name or "vgg_7"
backend = backend or "cunn" backend = backend or "cunn"
@ -150,12 +289,14 @@ function srcnn.create(model_name, backend, color)
else else
error("unsupported color: " .. color) error("unsupported color: " .. color)
end end
if model_name == "vgg_7" then if srcnn[model_name] then
return srcnn.vgg_7(backend, ch) return srcnn[model_name](backend, ch)
elseif model_name == "vgg_12" then
return srcnn.vgg_12(backend, ch)
else else
error("unsupported model_name: " .. model_name) error("unsupported model_name: " .. model_name)
end end
end end
--local model = srcnn.upconv_8_4x("cunn", 3):cuda()
--print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
return srcnn return srcnn

View file

@ -15,7 +15,9 @@ local pairwise_transform = require 'pairwise_transform'
local image_loader = require 'image_loader' local image_loader = require 'image_loader'
local function save_test_scale(model, rgb, file) local function save_test_scale(model, rgb, file)
local up = reconstruct.scale(model, settings.scale, rgb, 128, settings.upsampling_filter) local up = reconstruct.scale(model, settings.scale, rgb,
settings.scale * settings.crop_size,
settings.upsampling_filter)
image.save(file, up) image.save(file, up)
end end
local function save_test_jpeg(model, rgb, file) local function save_test_jpeg(model, rgb, file)
@ -96,6 +98,7 @@ local function create_criterion(model)
local offset = reconstruct.offset_size(model) local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2 local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w) local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.29891 * 3) -- R weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B weight[3]:fill(0.11448 * 3) -- B
@ -108,7 +111,7 @@ local function create_criterion(model)
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
end end
end end
local function transformer(x, is_validation, n, offset) local function transformer(model, x, is_validation, n, offset)
x = compression.decompress(x) x = compression.decompress(x)
n = n or settings.patches n = n or settings.patches
@ -145,7 +148,8 @@ local function transformer(x, is_validation, n, offset)
active_cropping_rate = active_cropping_rate, active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries, active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb"), rgb = (settings.color == "rgb"),
gamma_correction = settings.gamma_correction gamma_correction = settings.gamma_correction,
x_upsampling = not srcnn.has_resize(model)
}) })
elseif settings.method == "noise" then elseif settings.method == "noise" then
return pairwise_transform.jpeg(x, return pairwise_transform.jpeg(x,
@ -183,6 +187,22 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
end end
end end
end end
-- Filter out compressed training images whose dimensions, divided by
-- settings.scale, cannot supply a crop_size patch plus a 16px margin.
-- `x` is an array of compressed images; returns the surviving subset.
-- NOTE(review): assumes compression.size() returns {channels?, height,
-- width} so that [2]/[3] are the spatial dims — confirm against the
-- compression module.
local function remove_small_image(x)
   local kept = {}
   local min_side = settings.crop_size + 16
   for i = 1, #x do
      local size = compression.size(x[i])
      local tall_enough = size[2] / settings.scale > min_side
      local wide_enough = size[3] / settings.scale > min_side
      if tall_enough and wide_enough then
         kept[#kept + 1] = x[i]
      end
      -- decompressing sizes allocates; collect periodically on large sets
      if i % 100 == 0 then
         collectgarbage()
      end
   end
   print(string.format("removed %d small images", #x - #kept))
   return kept
end
local function plot(train, valid) local function plot(train, valid)
gnuplot.plot({ gnuplot.plot({
{'training', torch.Tensor(train), '-'}, {'training', torch.Tensor(train), '-'},
@ -194,11 +214,11 @@ local function train()
local model = srcnn.create(settings.model, settings.backend, settings.color) local model = srcnn.create(settings.model, settings.backend, settings.color)
local offset = reconstruct.offset_size(model) local offset = reconstruct.offset_size(model)
local pairwise_func = function(x, is_validation, n) local pairwise_func = function(x, is_validation, n)
return transformer(x, is_validation, n, offset) return transformer(model, x, is_validation, n, offset)
end end
local criterion = create_criterion(model) local criterion = create_criterion(model)
local eval_metric = nn.MSECriterion():cuda() local eval_metric = nn.MSECriterion():cuda()
local x = torch.load(settings.images) local x = remove_small_image(torch.load(settings.images))
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1)) local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
local adam_config = { local adam_config = {
learningRate = settings.learning_rate, learningRate = settings.learning_rate,
@ -222,10 +242,16 @@ local function train()
model:cuda() model:cuda()
print("load .. " .. #train_x) print("load .. " .. #train_x)
local x = torch.Tensor(settings.patches * #train_x, local x = nil
ch, settings.crop_size, settings.crop_size)
local y = torch.Tensor(settings.patches * #train_x, local y = torch.Tensor(settings.patches * #train_x,
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero() ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
if srcnn.has_resize(model) then
x = torch.Tensor(settings.patches * #train_x,
ch, settings.crop_size / settings.scale, settings.crop_size / settings.scale)
else
x = torch.Tensor(settings.patches * #train_x,
ch, settings.crop_size, settings.crop_size)
end
for epoch = 1, settings.epoch do for epoch = 1, settings.epoch do
model:training() model:training()
print("# " .. epoch) print("# " .. epoch)