Add new models
upconv_7 is 2.3x faster than previous model
This commit is contained in:
parent
e62305377f
commit
51ae485cd1
|
@ -1,255 +1,9 @@
|
|||
require 'image'
|
||||
local gm = require 'graphicsmagick'
|
||||
local iproc = require 'iproc'
|
||||
local data_augmentation = require 'data_augmentation'
|
||||
|
||||
require 'pl'
|
||||
local pairwise_transform = {}
|
||||
|
||||
local function random_half(src, p, filters)
|
||||
if torch.uniform() < p then
|
||||
local filter = filters[torch.random(1, #filters)]
|
||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
local function crop_if_large(src, max_size)
|
||||
local tries = 4
|
||||
if src:size(2) > max_size and src:size(3) > max_size then
|
||||
local rect
|
||||
for i = 1, tries do
|
||||
local yi = torch.random(0, src:size(2) - max_size)
|
||||
local xi = torch.random(0, src:size(3) - max_size)
|
||||
rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
|
||||
-- ignore simple background
|
||||
if rect:float():std() >= 0 then
|
||||
break
|
||||
end
|
||||
end
|
||||
return rect
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
local function preprocess(src, crop_size, options)
|
||||
local dest = src
|
||||
dest = random_half(dest, options.random_half_rate, options.downsampling_filters)
|
||||
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
dest = data_augmentation.shift_1px(dest)
|
||||
|
||||
return dest
|
||||
end
|
||||
local function active_cropping(x, y, size, p, tries)
|
||||
assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
local r = torch.uniform()
|
||||
local t = "float"
|
||||
if x:type() == "torch.ByteTensor" then
|
||||
t = "byte"
|
||||
end
|
||||
if p < r then
|
||||
local xi = torch.random(0, y:size(3) - (size + 1))
|
||||
local yi = torch.random(0, y:size(2) - (size + 1))
|
||||
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
|
||||
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||
return xc, yc
|
||||
else
|
||||
local lowres = gm.Image(x, "RGB", "DHW"):
|
||||
size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
|
||||
size(x:size(3), x:size(2), "Box"):
|
||||
toTensor(t, "RGB", "DHW")
|
||||
local best_se = 0.0
|
||||
local best_xc, best_yc
|
||||
local m = torch.FloatTensor(x:size(1), size, size)
|
||||
for i = 1, tries do
|
||||
local xi = torch.random(0, y:size(3) - (size + 1))
|
||||
local yi = torch.random(0, y:size(2) - (size + 1))
|
||||
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
|
||||
local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
|
||||
local xcf = iproc.byte2float(xc)
|
||||
local lcf = iproc.byte2float(lc)
|
||||
local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
|
||||
if se >= best_se then
|
||||
best_xc = xcf
|
||||
best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size))
|
||||
best_se = se
|
||||
end
|
||||
end
|
||||
return best_xc, best_yc
|
||||
end
|
||||
end
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
local filters = options.downsampling_filters
|
||||
local unstable_region_offset = 8
|
||||
local downsampling_filter = filters[torch.random(1, #filters)]
|
||||
local y = preprocess(src, size, options)
|
||||
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
|
||||
local down_scale = 1.0 / scale
|
||||
local x
|
||||
if options.gamma_correction then
|
||||
x = iproc.scale(iproc.scale_with_gamma22(y, y:size(3) * down_scale,
|
||||
y:size(2) * down_scale, downsampling_filter),
|
||||
y:size(3), y:size(2), options.upsampling_filter)
|
||||
else
|
||||
x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
|
||||
y:size(2) * down_scale, downsampling_filter),
|
||||
y:size(3), y:size(2), options.upsampling_filter)
|
||||
end
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
||||
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
||||
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
||||
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
local xc, yc = active_cropping(x, y,
|
||||
size,
|
||||
options.active_cropping_rate,
|
||||
options.active_cropping_tries)
|
||||
xc = iproc.byte2float(xc)
|
||||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
end
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
||||
local unstable_region_offset = 8
|
||||
local y = preprocess(src, size, options)
|
||||
local x = y
|
||||
pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_scale'))
|
||||
pairwise_transform = tablex.update(pairwise_transform, require('pairwise_transform_jpeg'))
|
||||
|
||||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg"):depth(8)
|
||||
if torch.uniform() < options.jpeg_chroma_subsampling_rate then
|
||||
-- YUV 420
|
||||
x:samplingFactors({2.0, 1.0, 1.0})
|
||||
else
|
||||
-- YUV 444
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
end
|
||||
local blob, len = x:toBlob(quality[i])
|
||||
x:fromBlob(blob, len)
|
||||
x = x:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
||||
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
||||
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
||||
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
local xc, yc = active_cropping(x, y, size,
|
||||
options.active_cropping_rate,
|
||||
options.active_cropping_tries)
|
||||
xc = iproc.byte2float(xc)
|
||||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
end
|
||||
if torch.uniform() < options.nr_rate then
|
||||
-- reducing noise
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
else
|
||||
-- ratain useful details
|
||||
table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
end
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
|
||||
if style == "art" then
|
||||
if level == 1 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
|
||||
size, offset, n, options)
|
||||
elseif level == 2 or level == 3 then
|
||||
-- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
|
||||
local r = torch.uniform()
|
||||
if r > 0.6 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
|
||||
size, offset, n, options)
|
||||
elseif r > 0.3 then
|
||||
local quality1 = torch.random(37, 70)
|
||||
local quality2 = quality1 - torch.random(5, 10)
|
||||
return pairwise_transform.jpeg_(src, {quality1, quality2},
|
||||
size, offset, n, options)
|
||||
else
|
||||
local quality1 = torch.random(52, 70)
|
||||
local quality2 = quality1 - torch.random(5, 15)
|
||||
local quality3 = quality1 - torch.random(15, 25)
|
||||
|
||||
return pairwise_transform.jpeg_(src,
|
||||
{quality1, quality2, quality3},
|
||||
size, offset, n, options)
|
||||
end
|
||||
else
|
||||
error("unknown noise level: " .. level)
|
||||
end
|
||||
elseif style == "photo" then
|
||||
-- level adjusting by -nr_rate
|
||||
return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
|
||||
size, offset, n,
|
||||
options)
|
||||
else
|
||||
error("unknown style: " .. style)
|
||||
end
|
||||
end
|
||||
print(pairwise_transform)
|
||||
|
||||
function pairwise_transform.test_jpeg(src)
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
local options = {random_color_noise_rate = 0.5,
|
||||
random_half_rate = 0.5,
|
||||
random_overlay_rate = 0.5,
|
||||
random_unsharp_mask_rate = 0.5,
|
||||
jpeg_chroma_subsampling_rate = 0.5,
|
||||
nr_rate = 1.0,
|
||||
active_cropping_rate = 0.5,
|
||||
active_cropping_tries = 10,
|
||||
max_size = 256,
|
||||
rgb = true
|
||||
}
|
||||
local image = require 'image'
|
||||
local src = image.lena()
|
||||
for i = 1, 9 do
|
||||
local xy = pairwise_transform.jpeg(src,
|
||||
"art",
|
||||
torch.random(1, 2),
|
||||
128, 7, 1, options)
|
||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
|
||||
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
|
||||
end
|
||||
end
|
||||
function pairwise_transform.test_scale(src)
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
local options = {random_color_noise_rate = 0.5,
|
||||
random_half_rate = 0.5,
|
||||
random_overlay_rate = 0.5,
|
||||
random_unsharp_mask_rate = 0.5,
|
||||
active_cropping_rate = 0.5,
|
||||
active_cropping_tries = 10,
|
||||
max_size = 256,
|
||||
rgb = true
|
||||
}
|
||||
local image = require 'image'
|
||||
local src = image.lena()
|
||||
|
||||
for i = 1, 10 do
|
||||
local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
|
||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
|
||||
end
|
||||
end
|
||||
return pairwise_transform
|
||||
|
|
117
lib/pairwise_transform_jpeg.lua
Normal file
117
lib/pairwise_transform_jpeg.lua
Normal file
|
@ -0,0 +1,117 @@
|
|||
local pairwise_utils = require 'pairwise_transform_utils'
|
||||
local gm = require 'graphicsmagick'
|
||||
local iproc = require 'iproc'
|
||||
local pairwise_transform = {}
|
||||
|
||||
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
||||
local unstable_region_offset = 8
|
||||
local y = pairwise_utils.preprocess(src, size, options)
|
||||
local x = y
|
||||
|
||||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg"):depth(8)
|
||||
if torch.uniform() < options.jpeg_chroma_subsampling_rate then
|
||||
-- YUV 420
|
||||
x:samplingFactors({2.0, 1.0, 1.0})
|
||||
else
|
||||
-- YUV 444
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
end
|
||||
local blob, len = x:toBlob(quality[i])
|
||||
x:fromBlob(blob, len)
|
||||
x = x:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
||||
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
||||
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
||||
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
local xc, yc = pairwise_utils.active_cropping(x, y, size, 1,
|
||||
options.active_cropping_rate,
|
||||
options.active_cropping_tries)
|
||||
xc = iproc.byte2float(xc)
|
||||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
end
|
||||
if torch.uniform() < options.nr_rate then
|
||||
-- reducing noise
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
else
|
||||
-- ratain useful details
|
||||
table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
end
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
|
||||
if style == "art" then
|
||||
if level == 1 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
|
||||
size, offset, n, options)
|
||||
elseif level == 2 or level == 3 then
|
||||
-- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
|
||||
local r = torch.uniform()
|
||||
if r > 0.6 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
|
||||
size, offset, n, options)
|
||||
elseif r > 0.3 then
|
||||
local quality1 = torch.random(37, 70)
|
||||
local quality2 = quality1 - torch.random(5, 10)
|
||||
return pairwise_transform.jpeg_(src, {quality1, quality2},
|
||||
size, offset, n, options)
|
||||
else
|
||||
local quality1 = torch.random(52, 70)
|
||||
local quality2 = quality1 - torch.random(5, 15)
|
||||
local quality3 = quality1 - torch.random(15, 25)
|
||||
|
||||
return pairwise_transform.jpeg_(src,
|
||||
{quality1, quality2, quality3},
|
||||
size, offset, n, options)
|
||||
end
|
||||
else
|
||||
error("unknown noise level: " .. level)
|
||||
end
|
||||
elseif style == "photo" then
|
||||
-- level adjusting by -nr_rate
|
||||
return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
|
||||
size, offset, n,
|
||||
options)
|
||||
else
|
||||
error("unknown style: " .. style)
|
||||
end
|
||||
end
|
||||
|
||||
function pairwise_transform.test_jpeg(src)
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
local options = {random_color_noise_rate = 0.5,
|
||||
random_half_rate = 0.5,
|
||||
random_overlay_rate = 0.5,
|
||||
random_unsharp_mask_rate = 0.5,
|
||||
jpeg_chroma_subsampling_rate = 0.5,
|
||||
nr_rate = 1.0,
|
||||
active_cropping_rate = 0.5,
|
||||
active_cropping_tries = 10,
|
||||
max_size = 256,
|
||||
rgb = true
|
||||
}
|
||||
local image = require 'image'
|
||||
local src = image.lena()
|
||||
for i = 1, 9 do
|
||||
local xy = pairwise_transform.jpeg(src,
|
||||
"art",
|
||||
torch.random(1, 2),
|
||||
128, 7, 1, options)
|
||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
|
||||
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
|
||||
end
|
||||
end
|
||||
return pairwise_transform
|
||||
|
86
lib/pairwise_transform_scale.lua
Normal file
86
lib/pairwise_transform_scale.lua
Normal file
|
@ -0,0 +1,86 @@
|
|||
local pairwise_utils = require 'pairwise_transform_utils'
|
||||
local iproc = require 'iproc'
|
||||
local pairwise_transform = {}
|
||||
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
local filters = options.downsampling_filters
|
||||
local unstable_region_offset = 8
|
||||
local downsampling_filter = filters[torch.random(1, #filters)]
|
||||
local y = pairwise_utils.preprocess(src, size, options)
|
||||
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
|
||||
local down_scale = 1.0 / scale
|
||||
local x
|
||||
if options.gamma_correction then
|
||||
local small = iproc.scale_with_gamma22(y, y:size(3) * down_scale,
|
||||
y:size(2) * down_scale, downsampling_filter)
|
||||
if options.x_upsampling then
|
||||
x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter)
|
||||
else
|
||||
x = small
|
||||
end
|
||||
else
|
||||
local small = iproc.scale(y, y:size(3) * down_scale,
|
||||
y:size(2) * down_scale, downsampling_filter)
|
||||
if options.x_upsampling then
|
||||
x = iproc.scale(small, y:size(3), y:size(2), options.upsampling_filter)
|
||||
else
|
||||
x = small
|
||||
end
|
||||
end
|
||||
|
||||
if options.x_upsampling then
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
||||
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
||||
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
||||
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
else
|
||||
assert(x:size(1) == y:size(1) and x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
|
||||
end
|
||||
local scale_inner = scale
|
||||
if options.x_upsampling then
|
||||
scale_inner = 1
|
||||
end
|
||||
local batch = {}
|
||||
|
||||
for i = 1, n do
|
||||
local xc, yc = pairwise_utils.active_cropping(x, y,
|
||||
size,
|
||||
scale_inner,
|
||||
options.active_cropping_rate,
|
||||
options.active_cropping_tries)
|
||||
xc = iproc.byte2float(xc)
|
||||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
end
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.test_scale(src)
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
local options = {random_color_noise_rate = 0.5,
|
||||
random_half_rate = 0.5,
|
||||
random_overlay_rate = 0.5,
|
||||
random_unsharp_mask_rate = 0.5,
|
||||
active_cropping_rate = 0.5,
|
||||
active_cropping_tries = 10,
|
||||
max_size = 256,
|
||||
x_upsampling = false,
|
||||
downsampling_filters = "Box",
|
||||
rgb = true
|
||||
}
|
||||
local image = require 'image'
|
||||
local src = image.lena()
|
||||
|
||||
for i = 1, 10 do
|
||||
local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
|
||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
|
||||
end
|
||||
end
|
||||
return pairwise_transform
|
91
lib/pairwise_transform_utils.lua
Normal file
91
lib/pairwise_transform_utils.lua
Normal file
|
@ -0,0 +1,91 @@
|
|||
require 'image'
|
||||
local gm = require 'graphicsmagick'
|
||||
local iproc = require 'iproc'
|
||||
local data_augmentation = require 'data_augmentation'
|
||||
local pairwise_transform_utils = {}
|
||||
|
||||
function pairwise_transform_utils.random_half(src, p, filters)
|
||||
if torch.uniform() < p then
|
||||
local filter = filters[torch.random(1, #filters)]
|
||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
function pairwise_transform_utils.crop_if_large(src, max_size)
|
||||
local tries = 4
|
||||
if src:size(2) > max_size and src:size(3) > max_size then
|
||||
local rect
|
||||
for i = 1, tries do
|
||||
local yi = torch.random(0, src:size(2) - max_size)
|
||||
local xi = torch.random(0, src:size(3) - max_size)
|
||||
rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
|
||||
-- ignore simple background
|
||||
if rect:float():std() >= 0 then
|
||||
break
|
||||
end
|
||||
end
|
||||
return rect
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
function pairwise_transform_utils.preprocess(src, crop_size, options)
|
||||
local dest = src
|
||||
dest = pairwise_transform_utils.random_half(dest, options.random_half_rate, options.downsampling_filters)
|
||||
dest = pairwise_transform_utils.crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
dest = data_augmentation.shift_1px(dest)
|
||||
|
||||
return dest
|
||||
end
|
||||
function pairwise_transform_utils.active_cropping(x, y, size, scale, p, tries)
|
||||
assert("x:size == y:size", x:size(2) * scale == y:size(2) and x:size(3) * scale == y:size(3))
|
||||
assert("crop_size % scale == 0", size % scale == 0)
|
||||
local r = torch.uniform()
|
||||
local t = "float"
|
||||
if x:type() == "torch.ByteTensor" then
|
||||
t = "byte"
|
||||
end
|
||||
if p < r then
|
||||
local xi = torch.random(0, x:size(3) - (size + 1))
|
||||
local yi = torch.random(0, x:size(2) - (size + 1))
|
||||
local yc = iproc.crop(y, xi * scale, yi * scale, xi * scale + size, yi * scale + size)
|
||||
local xc = iproc.crop(x, xi, yi, xi + size / scale, yi + size / scale)
|
||||
return xc, yc
|
||||
else
|
||||
local test_scale = 2
|
||||
if test_scale < scale then
|
||||
test_scale = scale
|
||||
end
|
||||
local lowres = gm.Image(y, "RGB", "DHW"):
|
||||
size(y:size(3) * 0.5, y:size(2) * 0.5, "Box"):
|
||||
size(y:size(3), y:size(2), "Box"):
|
||||
toTensor(t, "RGB", "DHW")
|
||||
local best_se = 0.0
|
||||
local best_xi, best_yi
|
||||
local m = torch.FloatTensor(y:size(1), size, size)
|
||||
for i = 1, tries do
|
||||
local xi = torch.random(0, x:size(3) - (size + 1)) * scale
|
||||
local yi = torch.random(0, x:size(2) - (size + 1)) * scale
|
||||
local xc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||
local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
|
||||
local xcf = iproc.byte2float(xc)
|
||||
local lcf = iproc.byte2float(lc)
|
||||
local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
|
||||
if se >= best_se then
|
||||
best_xi = xi
|
||||
best_yi = yi
|
||||
best_se = se
|
||||
end
|
||||
end
|
||||
local yc = iproc.crop(y, best_xi, best_yi, best_xi + size, best_yi + size)
|
||||
local xc = iproc.crop(x, best_xi / scale, best_yi / scale, best_xi / scale + size / scale, best_yi / scale + size / scale)
|
||||
return xc, yc
|
||||
end
|
||||
end
|
||||
|
||||
return pairwise_transform_utils
|
|
@ -49,6 +49,32 @@ local function reconstruct_rgb(model, x, offset, block_size)
|
|||
end
|
||||
return new_x
|
||||
end
|
||||
local function reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
|
||||
local new_x = torch.Tensor(x:size(1), x:size(2) * scale, x:size(3) * scale):zero()
|
||||
local input_block_size = block_size / scale
|
||||
local output_block_size = block_size
|
||||
local output_size = output_block_size - offset * 2
|
||||
local output_size_in_input = input_block_size - offset
|
||||
local input = torch.CudaTensor(1, 3, input_block_size, input_block_size)
|
||||
|
||||
for i = 1, x:size(2), output_size_in_input do
|
||||
for j = 1, new_x:size(3), output_size_in_input do
|
||||
if i + input_block_size - 1 <= x:size(2) and j + input_block_size - 1 <= x:size(3) then
|
||||
local index = {{},
|
||||
{i, i + input_block_size - 1},
|
||||
{j, j + input_block_size - 1}}
|
||||
input:copy(x[index])
|
||||
local output = model:forward(input):view(3, output_size, output_size)
|
||||
local ii = (i - 1) * scale + 1
|
||||
local jj = (j - 1) * scale + 1
|
||||
local output_index = {{}, { ii , ii + output_size - 1 },
|
||||
{ jj, jj + output_size - 1}}
|
||||
new_x[output_index]:copy(output)
|
||||
end
|
||||
end
|
||||
end
|
||||
return new_x
|
||||
end
|
||||
local reconstruct = {}
|
||||
function reconstruct.is_rgb(model)
|
||||
if srcnn.channels(model) == 3 then
|
||||
|
@ -62,6 +88,9 @@ end
|
|||
function reconstruct.offset_size(model)
|
||||
return srcnn.offset_size(model)
|
||||
end
|
||||
function reconstruct.no_resize(model)
|
||||
return srcnn.has_resize(model)
|
||||
end
|
||||
function reconstruct.image_y(model, x, offset, block_size)
|
||||
block_size = block_size or 128
|
||||
local output_size = block_size - offset * 2
|
||||
|
@ -95,8 +124,14 @@ end
|
|||
function reconstruct.scale_y(model, scale, x, offset, block_size, upsampling_filter)
|
||||
upsampling_filter = upsampling_filter or "Box"
|
||||
block_size = block_size or 128
|
||||
local x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
|
||||
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
|
||||
|
||||
local x_lanczos
|
||||
if reconstruct.no_resize(model) then
|
||||
x_lanczos = x:clone()
|
||||
else
|
||||
x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
|
||||
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
|
||||
end
|
||||
if x:size(2) * x:size(3) > 2048*2048 then
|
||||
collectgarbage()
|
||||
end
|
||||
|
@ -162,39 +197,77 @@ function reconstruct.image_rgb(model, x, offset, block_size)
|
|||
return output
|
||||
end
|
||||
function reconstruct.scale_rgb(model, scale, x, offset, block_size, upsampling_filter)
|
||||
upsampling_filter = upsampling_filter or "Box"
|
||||
block_size = block_size or 128
|
||||
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
|
||||
if x:size(2) * x:size(3) > 2048*2048 then
|
||||
collectgarbage()
|
||||
end
|
||||
local output_size = block_size - offset * 2
|
||||
local h_blocks = math.floor(x:size(2) / output_size) +
|
||||
((x:size(2) % output_size == 0 and 0) or 1)
|
||||
local w_blocks = math.floor(x:size(3) / output_size) +
|
||||
((x:size(3) % output_size == 0 and 0) or 1)
|
||||
|
||||
local h = offset + h_blocks * output_size + offset
|
||||
local w = offset + w_blocks * output_size + offset
|
||||
local pad_h1 = offset
|
||||
local pad_w1 = offset
|
||||
local pad_h2 = (h - offset) - x:size(2)
|
||||
local pad_w2 = (w - offset) - x:size(3)
|
||||
x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
|
||||
if x:size(2) * x:size(3) > 2048*2048 then
|
||||
collectgarbage()
|
||||
end
|
||||
local y = reconstruct_rgb(model, x, offset, block_size)
|
||||
local output = iproc.crop(y,
|
||||
pad_w1, pad_h1,
|
||||
y:size(3) - pad_w2, y:size(2) - pad_h2)
|
||||
output[torch.lt(output, 0)] = 0
|
||||
output[torch.gt(output, 1)] = 1
|
||||
x = nil
|
||||
y = nil
|
||||
collectgarbage()
|
||||
if reconstruct.no_resize(model) then
|
||||
block_size = block_size or 128
|
||||
local input_block_size = block_size / scale
|
||||
local x_w = x:size(3)
|
||||
local x_h = x:size(2)
|
||||
local process_size = input_block_size - offset * 2
|
||||
-- TODO: under construction!! bug in 4x
|
||||
local h_blocks = math.floor(x_h / process_size) + 2
|
||||
-- ((x_h % process_size == 0 and 0) or 1)
|
||||
local w_blocks = math.floor(x_w / process_size) + 2
|
||||
-- ((x_w % process_size == 0 and 0) or 1)
|
||||
local h = offset + (h_blocks * process_size) + offset
|
||||
local w = offset + (w_blocks * process_size) + offset
|
||||
local pad_h1 = offset
|
||||
local pad_w1 = offset
|
||||
|
||||
return output
|
||||
local pad_h2 = (h - offset) - x:size(2)
|
||||
local pad_w2 = (w - offset) - x:size(3)
|
||||
|
||||
x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
|
||||
if x:size(2) * x:size(3) > 2048*2048 then
|
||||
collectgarbage()
|
||||
end
|
||||
local y
|
||||
y = reconstruct_rgb_with_scale(model, x, scale, offset, block_size)
|
||||
local output = iproc.crop(y,
|
||||
pad_w1, pad_h1,
|
||||
pad_w1 + x_w * scale, pad_h1 + x_h * scale)
|
||||
output[torch.lt(output, 0)] = 0
|
||||
output[torch.gt(output, 1)] = 1
|
||||
x = nil
|
||||
y = nil
|
||||
collectgarbage()
|
||||
|
||||
return output
|
||||
else
|
||||
upsampling_filter = upsampling_filter or "Box"
|
||||
block_size = block_size or 128
|
||||
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, upsampling_filter)
|
||||
if x:size(2) * x:size(3) > 2048*2048 then
|
||||
collectgarbage()
|
||||
end
|
||||
local output_size = block_size - offset * 2
|
||||
local h_blocks = math.floor(x:size(2) / output_size) +
|
||||
((x:size(2) % output_size == 0 and 0) or 1)
|
||||
local w_blocks = math.floor(x:size(3) / output_size) +
|
||||
((x:size(3) % output_size == 0 and 0) or 1)
|
||||
|
||||
local h = offset + h_blocks * output_size + offset
|
||||
local w = offset + w_blocks * output_size + offset
|
||||
local pad_h1 = offset
|
||||
local pad_w1 = offset
|
||||
local pad_h2 = (h - offset) - x:size(2)
|
||||
local pad_w2 = (w - offset) - x:size(3)
|
||||
x = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
|
||||
if x:size(2) * x:size(3) > 2048*2048 then
|
||||
collectgarbage()
|
||||
end
|
||||
local y
|
||||
y = reconstruct_rgb(model, x, offset, block_size)
|
||||
local output = iproc.crop(y,
|
||||
pad_w1, pad_h1,
|
||||
y:size(3) - pad_w2, y:size(2) - pad_h2)
|
||||
output[torch.lt(output, 0)] = 0
|
||||
output[torch.gt(output, 1)] = 1
|
||||
x = nil
|
||||
y = nil
|
||||
collectgarbage()
|
||||
|
||||
return output
|
||||
end
|
||||
end
|
||||
|
||||
function reconstruct.image(model, x, block_size)
|
||||
|
|
|
@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
|||
cmd:option("-test", "images/miku_small.png", 'path to test image')
|
||||
cmd:option("-model_dir", "./models", 'model directory')
|
||||
cmd:option("-method", "scale", 'method to training (noise|scale)')
|
||||
cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12)')
|
||||
cmd:option("-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)')
|
||||
cmd:option("-noise_level", 1, '(1|2|3)')
|
||||
cmd:option("-style", "art", '(art|photo)')
|
||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||
|
@ -34,7 +34,7 @@ cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution im
|
|||
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
|
||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||
cmd:option("-learning_rate", 0.0005, 'learning rate for adam')
|
||||
cmd:option("-crop_size", 46, 'crop size')
|
||||
cmd:option("-crop_size", 48, 'crop size')
|
||||
cmd:option("-max_size", 256, 'if image is larger than N, image will be crop randomly')
|
||||
cmd:option("-batch_size", 8, 'mini batch size')
|
||||
cmd:option("-patches", 16, 'number of patch samples')
|
||||
|
|
195
lib/srcnn.lua
195
lib/srcnn.lua
|
@ -9,14 +9,23 @@ function nn.SpatialConvolutionMM:reset(stdv)
|
|||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
end
|
||||
function nn.SpatialFullConvolution:reset(stdv)
|
||||
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
end
|
||||
if cudnn and cudnn.SpatialConvolution then
|
||||
function cudnn.SpatialConvolution:reset(stdv)
|
||||
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
end
|
||||
function cudnn.SpatialFullConvolution:reset(stdv)
|
||||
stdv = math.sqrt(2 / ((1.0 + 0.1 * 0.1) * self.kW * self.kH * self.nOutputPlane))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
end
|
||||
end
|
||||
|
||||
function nn.SpatialConvolutionMM:clearState()
|
||||
if self.gradWeight then
|
||||
self.gradWeight:resize(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):zero()
|
||||
|
@ -26,9 +35,12 @@ function nn.SpatialConvolutionMM:clearState()
|
|||
end
|
||||
return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
|
||||
end
|
||||
|
||||
function srcnn.channels(model)
|
||||
return model:get(model:size() - 1).weight:size(1)
|
||||
if model.w2nn_channels ~= nil then
|
||||
return model.w2nn_channels
|
||||
else
|
||||
return model:get(model:size() - 1).weight:size(1)
|
||||
end
|
||||
end
|
||||
function srcnn.backend(model)
|
||||
local conv = model:findModules("cudnn.SpatialConvolution")
|
||||
|
@ -47,32 +59,54 @@ function srcnn.color(model)
|
|||
end
|
||||
end
|
||||
function srcnn.name(model)
|
||||
local backend_cudnn = false
|
||||
local conv = model:findModules("nn.SpatialConvolutionMM")
|
||||
if #conv == 0 then
|
||||
backend_cudnn = true
|
||||
conv = model:findModules("cudnn.SpatialConvolution")
|
||||
end
|
||||
if #conv == 7 then
|
||||
return "vgg_7"
|
||||
elseif #conv == 12 then
|
||||
return "vgg_12"
|
||||
if model.w2nn_arch_name then
|
||||
return model.w2nn_arch_name
|
||||
else
|
||||
return nil
|
||||
local conv = model:findModules("nn.SpatialConvolutionMM")
|
||||
if #conv == 0 then
|
||||
conv = model:findModules("cudnn.SpatialConvolution")
|
||||
end
|
||||
if #conv == 7 then
|
||||
return "vgg_7"
|
||||
elseif #conv == 12 then
|
||||
return "vgg_12"
|
||||
else
|
||||
error("unsupported model name")
|
||||
end
|
||||
end
|
||||
end
|
||||
function srcnn.offset_size(model)
|
||||
local conv = model:findModules("nn.SpatialConvolutionMM")
|
||||
if #conv == 0 then
|
||||
conv = model:findModules("cudnn.SpatialConvolution")
|
||||
if model.w2nn_offset ~= nil then
|
||||
return model.w2nn_offset
|
||||
else
|
||||
local name = srcnn.name(model)
|
||||
if name:match("vgg_") then
|
||||
local conv = model:findModules("nn.SpatialConvolutionMM")
|
||||
if #conv == 0 then
|
||||
conv = model:findModules("cudnn.SpatialConvolution")
|
||||
end
|
||||
local offset = 0
|
||||
for i = 1, #conv do
|
||||
offset = offset + (conv[i].kW - 1) / 2
|
||||
end
|
||||
return math.floor(offset)
|
||||
else
|
||||
error("unsupported model name")
|
||||
end
|
||||
end
|
||||
end
|
||||
function srcnn.has_resize(model)
|
||||
if model.w2nn_resize ~= nil then
|
||||
return model.w2nn_resize
|
||||
else
|
||||
local name = srcnn.name(model)
|
||||
if name:match("upconv") ~= nil then
|
||||
return true
|
||||
else
|
||||
return false
|
||||
end
|
||||
end
|
||||
local offset = 0
|
||||
for i = 1, #conv do
|
||||
offset = offset + (conv[i].kW - 1) / 2
|
||||
end
|
||||
return math.floor(offset)
|
||||
end
|
||||
|
||||
local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
|
||||
if backend == "cunn" then
|
||||
return nn.SpatialConvolutionMM(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
|
||||
|
@ -82,6 +116,15 @@ local function SpatialConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW
|
|||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
local function SpatialFullConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
|
||||
if backend == "cunn" then
|
||||
return nn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
|
||||
elseif backend == "cudnn" then
|
||||
return cudnn.SpatialFullConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
|
||||
else
|
||||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
|
||||
-- VGG style net(7 layers)
|
||||
function srcnn.vgg_7(backend, ch)
|
||||
|
@ -100,6 +143,11 @@ function srcnn.vgg_7(backend, ch)
|
|||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
model.w2nn_arch_name = "vgg_7"
|
||||
model.w2nn_offset = 7
|
||||
model.w2nn_resize = false
|
||||
model.w2nn_channels = ch
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
|
@ -132,12 +180,103 @@ function srcnn.vgg_12(backend, ch)
|
|||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
model.w2nn_arch_name = "vgg_12"
|
||||
model.w2nn_offset = 12
|
||||
model.w2nn_resize = false
|
||||
model.w2nn_channels = ch
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
return model
|
||||
end
|
||||
|
||||
-- Dilated Convolution (7 layers)
|
||||
function srcnn.dilated_7(backend, ch)
|
||||
local model = nn.Sequential()
|
||||
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
model.w2nn_arch_name = "dilated_7"
|
||||
model.w2nn_offset = 12
|
||||
model.w2nn_resize = false
|
||||
model.w2nn_channels = ch
|
||||
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
return model
|
||||
end
|
||||
|
||||
-- Up Convolution
|
||||
function srcnn.upconv_7(backend, ch)
|
||||
local model = nn.Sequential()
|
||||
|
||||
model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialFullConvolution(backend, 128, ch, 4, 4, 2, 2, 1, 1))
|
||||
|
||||
model.w2nn_arch_name = "upconv_7"
|
||||
model.w2nn_offset = 12
|
||||
model.w2nn_resize = true
|
||||
model.w2nn_channels = ch
|
||||
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
return model
|
||||
end
|
||||
function srcnn.upconv_8_4x(backend, ch)
|
||||
local model = nn.Sequential()
|
||||
|
||||
model:add(SpatialFullConvolution(backend, ch, 32, 4, 4, 2, 2, 1, 1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(SpatialFullConvolution(backend, 64, 3, 4, 4, 2, 2, 1, 1))
|
||||
|
||||
model.w2nn_arch_name = "upconv_8_4x"
|
||||
model.w2nn_offset = 12
|
||||
model.w2nn_resize = true
|
||||
model.w2nn_channels = ch
|
||||
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
return model
|
||||
end
|
||||
function srcnn.create(model_name, backend, color)
|
||||
model_name = model_name or "vgg_7"
|
||||
backend = backend or "cunn"
|
||||
|
@ -150,12 +289,14 @@ function srcnn.create(model_name, backend, color)
|
|||
else
|
||||
error("unsupported color: " .. color)
|
||||
end
|
||||
if model_name == "vgg_7" then
|
||||
return srcnn.vgg_7(backend, ch)
|
||||
elseif model_name == "vgg_12" then
|
||||
return srcnn.vgg_12(backend, ch)
|
||||
if srcnn[model_name] then
|
||||
return srcnn[model_name](backend, ch)
|
||||
else
|
||||
error("unsupported model_name: " .. model_name)
|
||||
end
|
||||
end
|
||||
|
||||
--local model = srcnn.upconv_8_4x("cunn", 3):cuda()
|
||||
--print(model:forward(torch.Tensor(1, 3, 64, 64):zero():cuda()):size())
|
||||
|
||||
return srcnn
|
||||
|
|
40
train.lua
40
train.lua
|
@ -15,7 +15,9 @@ local pairwise_transform = require 'pairwise_transform'
|
|||
local image_loader = require 'image_loader'
|
||||
|
||||
local function save_test_scale(model, rgb, file)
|
||||
local up = reconstruct.scale(model, settings.scale, rgb, 128, settings.upsampling_filter)
|
||||
local up = reconstruct.scale(model, settings.scale, rgb,
|
||||
settings.scale * settings.crop_size,
|
||||
settings.upsampling_filter)
|
||||
image.save(file, up)
|
||||
end
|
||||
local function save_test_jpeg(model, rgb, file)
|
||||
|
@ -96,6 +98,7 @@ local function create_criterion(model)
|
|||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(3, output_w * output_w)
|
||||
|
||||
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
|
@ -108,7 +111,7 @@ local function create_criterion(model)
|
|||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
end
|
||||
end
|
||||
local function transformer(x, is_validation, n, offset)
|
||||
local function transformer(model, x, is_validation, n, offset)
|
||||
x = compression.decompress(x)
|
||||
n = n or settings.patches
|
||||
|
||||
|
@ -145,7 +148,8 @@ local function transformer(x, is_validation, n, offset)
|
|||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb"),
|
||||
gamma_correction = settings.gamma_correction
|
||||
gamma_correction = settings.gamma_correction,
|
||||
x_upsampling = not srcnn.has_resize(model)
|
||||
})
|
||||
elseif settings.method == "noise" then
|
||||
return pairwise_transform.jpeg(x,
|
||||
|
@ -183,6 +187,22 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
|
|||
end
|
||||
end
|
||||
end
|
||||
local function remove_small_image(x)
|
||||
local new_x = {}
|
||||
for i = 1, #x do
|
||||
local x_s = compression.size(x[i])
|
||||
if x_s[2] / settings.scale > settings.crop_size + 16 and
|
||||
x_s[3] / settings.scale > settings.crop_size + 16 then
|
||||
table.insert(new_x, x[i])
|
||||
end
|
||||
if i % 100 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
print(string.format("removed %d small images", #x - #new_x))
|
||||
|
||||
return new_x
|
||||
end
|
||||
local function plot(train, valid)
|
||||
gnuplot.plot({
|
||||
{'training', torch.Tensor(train), '-'},
|
||||
|
@ -194,11 +214,11 @@ local function train()
|
|||
local model = srcnn.create(settings.model, settings.backend, settings.color)
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local pairwise_func = function(x, is_validation, n)
|
||||
return transformer(x, is_validation, n, offset)
|
||||
return transformer(model, x, is_validation, n, offset)
|
||||
end
|
||||
local criterion = create_criterion(model)
|
||||
local eval_metric = nn.MSECriterion():cuda()
|
||||
local x = torch.load(settings.images)
|
||||
local x = remove_small_image(torch.load(settings.images))
|
||||
local train_x, valid_x = split_data(x, math.max(math.floor(settings.validation_rate * #x), 1))
|
||||
local adam_config = {
|
||||
learningRate = settings.learning_rate,
|
||||
|
@ -222,10 +242,16 @@ local function train()
|
|||
model:cuda()
|
||||
print("load .. " .. #train_x)
|
||||
|
||||
local x = torch.Tensor(settings.patches * #train_x,
|
||||
ch, settings.crop_size, settings.crop_size)
|
||||
local x = nil
|
||||
local y = torch.Tensor(settings.patches * #train_x,
|
||||
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
|
||||
if srcnn.has_resize(model) then
|
||||
x = torch.Tensor(settings.patches * #train_x,
|
||||
ch, settings.crop_size / settings.scale, settings.crop_size / settings.scale)
|
||||
else
|
||||
x = torch.Tensor(settings.patches * #train_x,
|
||||
ch, settings.crop_size, settings.crop_size)
|
||||
end
|
||||
for epoch = 1, settings.epoch do
|
||||
model:training()
|
||||
print("# " .. epoch)
|
||||
|
|
Loading…
Reference in a new issue