1
0
Fork 0
mirror of synced 2024-06-14 00:44:32 +12:00

Merge pull request #32 from nagadomi/rgb

add support for RGB color space reconstruction
This commit is contained in:
nagadomi 2015-06-23 03:15:40 +09:00
commit 81529b1ab8
27 changed files with 194 additions and 50 deletions

View file

@ -5,7 +5,7 @@ require './lib/LeakyReLU'
local srcnn = require 'lib/srcnn'
local function cudnn2cunn(cudnn_model)
local cunn_model = srcnn.waifu2x()
local cunn_model = srcnn.waifu2x("y")
local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")

Binary file not shown.

Before

Width:  |  Height:  |  Size: 383 KiB

After

Width:  |  Height:  |  Size: 315 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 648 KiB

After

Width:  |  Height:  |  Size: 605 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 150 KiB

After

Width:  |  Height:  |  Size: 154 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 148 KiB

After

Width:  |  Height:  |  Size: 138 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 150 KiB

After

Width:  |  Height:  |  Size: 136 KiB

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 MiB

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 493 KiB

After

Width:  |  Height:  |  Size: 499 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 372 KiB

After

Width:  |  Height:  |  Size: 380 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 377 KiB

After

Width:  |  Height:  |  Size: 378 KiB

View file

@ -73,12 +73,18 @@ function image_loader.decode_byte(blob)
end
function image_loader.load_float(file)
local fp = io.open(file, "rb")
if not fp then
error(file .. ": failed to load image")
end
local buff = fp:read("*a")
fp:close()
return image_loader.decode_float(buff)
end
function image_loader.load_byte(file)
local fp = io.open(file, "rb")
if not fp then
error(file .. ": failed to load image")
end
local buff = fp:read("*a")
fp:close()
return image_loader.decode_byte(buff)

View file

@ -2,15 +2,6 @@ local gm = require 'graphicsmagick'
local image = require 'image'
local iproc = {}
function iproc.sample(src, width, height)
local t = "float"
if src:type() == "torch.ByteTensor" then
t = "byte"
end
local im = gm.Image(src, "RGB", "DHW")
im:sample(math.ceil(width), math.ceil(height))
return im:toTensor(t, "RGB", "DHW")
end
function iproc.scale(src, width, height, filter)
local t = "float"
if src:type() == "torch.ByteTensor" then

View file

@ -52,7 +52,7 @@ local function flip_augment(x, y)
end
local INTERPOLATION_PADDING = 16
function pairwise_transform.scale(src, scale, size, offset, options)
options = options or {color_augment = true, random_half = true}
options = options or {color_augment = true, random_half = true, rgb = true}
if options.random_half then
src = random_half(src)
end
@ -81,8 +81,12 @@ function pairwise_transform.scale(src, scale, size, offset, options)
x = iproc.scale(x, y:size(3), y:size(2))
y = y:float():div(255)
x = x:float():div(255)
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
end
y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
@ -90,7 +94,7 @@ function pairwise_transform.scale(src, scale, size, offset, options)
return x, y
end
function pairwise_transform.jpeg_(src, quality, size, offset, options)
options = options or {color_augment = true, random_half = true}
options = options or {color_augment = true, random_half = true, rgb = true}
if options.random_half then
src = random_half(src)
end
@ -106,6 +110,7 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg")
x:samplingFactors({1.0, 1.0, 1.0})
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
@ -117,9 +122,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
x = x:float():div(255)
x, y = flip_augment(x, y)
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
end
return x, image.crop(y, offset, offset, size - offset, size - offset)
end
function pairwise_transform.jpeg(src, level, size, offset, options)
@ -159,12 +167,12 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
local down_scale = 1.0 / scale
local filters = {
"Box", -- 0.012756949974688
--"Blackman", -- 0.013191924552285
"Blackman", -- 0.013191924552285
--"Cartom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
--"Hermite", -- 0.013850225205266
--"SincFast", -- 0.014095824314306
--"Jinc", -- 0.014244299255442
"SincFast", -- 0.014095824314306
"Jinc", -- 0.014244299255442
}
local downscale_filter = filters[torch.random(1, #filters)]
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
@ -180,6 +188,7 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg")
x:samplingFactors({1.0, 1.0, 1.0})
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
@ -194,10 +203,13 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
x = x:float():div(255)
y = y:float():div(255)
x, y = flip_augment(x, y)
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
end
return x, image.crop(y, offset, offset, size - offset, size - offset)
end
function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options)
@ -247,7 +259,7 @@ local function test_scale()
local loader = require './image_loader'
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
for i = 1, 9 do
local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true})
local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true, rgb = true})
image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
print(y:size(), x:size())
@ -272,8 +284,8 @@ local function test_jpeg_scale()
--print(x:mean(), y:mean())
end
end
--test_jpeg()
--test_scale()
--test_jpeg()
--test_jpeg_scale()
return pairwise_transform

View file

@ -1,7 +1,7 @@
require 'image'
local iproc = require './iproc'
local function reconstruct_layer(model, x, offset, block_size)
local function reconstruct_y(model, x, offset, block_size)
if x:dim() == 2 then
x = x:reshape(1, x:size(1), x:size(2))
end
@ -26,8 +26,40 @@ local function reconstruct_layer(model, x, offset, block_size)
end
return new_x
end
local function reconstruct_rgb(model, x, offset, block_size)
local new_x = torch.Tensor():resizeAs(x):zero()
local output_size = block_size - offset * 2
local input = torch.CudaTensor(1, 3, block_size, block_size)
for i = 1, x:size(2), output_size do
for j = 1, x:size(3), output_size do
if i + block_size - 1 <= x:size(2) and j + block_size - 1 <= x:size(3) then
local index = {{},
{i, i + block_size - 1},
{j, j + block_size - 1}}
input:copy(x[index])
local output = model:forward(input):float():view(3, output_size, output_size)
local output_index = {{},
{i + offset, offset + i + output_size - 1},
{offset + j, offset + j + output_size - 1}}
new_x[output_index]:copy(output)
end
end
end
return new_x
end
function model_is_rgb(model)
if model:get(model:size() - 1).weight:size(1) == 3 then
-- 3ch RGB
return true
else
-- 1ch Y
return false
end
end
local reconstruct = {}
function reconstruct.image(model, x, offset, block_size)
function reconstruct.image_y(model, x, offset, block_size)
block_size = block_size or 128
local output_size = block_size - offset * 2
local h_blocks = math.floor(x:size(2) / output_size) +
@ -42,7 +74,7 @@ function reconstruct.image(model, x, offset, block_size)
local pad_h2 = (h - offset) - x:size(2)
local pad_w2 = (w - offset) - x:size(3)
local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
local y = reconstruct_layer(model, yuv[1], offset, block_size)
local y = reconstruct_y(model, yuv[1], offset, block_size)
y[torch.lt(y, 0)] = 0
y[torch.gt(y, 1)] = 1
yuv[1]:copy(y)
@ -55,7 +87,7 @@ function reconstruct.image(model, x, offset, block_size)
return output
end
function reconstruct.scale(model, scale, x, offset, block_size)
function reconstruct.scale_y(model, scale, x, offset, block_size)
block_size = block_size or 128
local x_jinc = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Jinc")
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
@ -74,7 +106,7 @@ function reconstruct.scale(model, scale, x, offset, block_size)
local pad_w2 = (w - offset) - x:size(3)
local yuv_nn = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
local yuv_jinc = image.rgb2yuv(iproc.padding(x_jinc, pad_w1, pad_w2, pad_h1, pad_h2))
local y = reconstruct_layer(model, yuv_nn[1], offset, block_size)
local y = reconstruct_y(model, yuv_nn[1], offset, block_size)
y[torch.lt(y, 0)] = 0
y[torch.gt(y, 1)] = 1
yuv_jinc[1]:copy(y)
@ -87,5 +119,72 @@ function reconstruct.scale(model, scale, x, offset, block_size)
return output
end
function reconstruct.image_rgb(model, x, offset, block_size)
block_size = block_size or 128
local output_size = block_size - offset * 2
local h_blocks = math.floor(x:size(2) / output_size) +
((x:size(2) % output_size == 0 and 0) or 1)
local w_blocks = math.floor(x:size(3) / output_size) +
((x:size(3) % output_size == 0 and 0) or 1)
local h = offset + h_blocks * output_size + offset
local w = offset + w_blocks * output_size + offset
local pad_h1 = offset
local pad_w1 = offset
local pad_h2 = (h - offset) - x:size(2)
local pad_w2 = (w - offset) - x:size(3)
local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
local y = reconstruct_rgb(model, input, offset, block_size)
local output = image.crop(y,
pad_w1, pad_h1,
y:size(3) - pad_w2, y:size(2) - pad_h2)
collectgarbage()
output[torch.lt(output, 0)] = 0
output[torch.gt(output, 1)] = 1
return output
end
function reconstruct.scale_rgb(model, scale, x, offset, block_size)
block_size = block_size or 128
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
local output_size = block_size - offset * 2
local h_blocks = math.floor(x:size(2) / output_size) +
((x:size(2) % output_size == 0 and 0) or 1)
local w_blocks = math.floor(x:size(3) / output_size) +
((x:size(3) % output_size == 0 and 0) or 1)
local h = offset + h_blocks * output_size + offset
local w = offset + w_blocks * output_size + offset
local pad_h1 = offset
local pad_w1 = offset
local pad_h2 = (h - offset) - x:size(2)
local pad_w2 = (w - offset) - x:size(3)
local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
local y = reconstruct_rgb(model, input, offset, block_size)
local output = image.crop(y,
pad_w1, pad_h1,
y:size(3) - pad_w2, y:size(2) - pad_h2)
output[torch.lt(output, 0)] = 0
output[torch.gt(output, 1)] = 1
collectgarbage()
return output
end
function reconstruct.image(model, x, offset, block_size)
if model_is_rgb(model) then
return reconstruct.image_rgb(model, x, offset, block_size)
else
return reconstruct.image_y(model, x, offset, block_size)
end
end
function reconstruct.scale(model, scale, x, offset, block_size)
if model_is_rgb(model) then
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
else
return reconstruct.scale_y(model, scale, x, offset, block_size)
end
end
return reconstruct

View file

@ -22,6 +22,7 @@ cmd:option("-test", "images/miku_small.png", 'test image file')
cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-color", 'rgb', '(y|rgb)')
cmd:option("-scale", 2.0, 'scale')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
@ -46,6 +47,9 @@ elseif settings.method == "noise_scale" then
else
error("unknown method: " .. settings.method)
end
if not (settings.color == "rgb" or settings.color == "y") then
error("color must be y or rgb")
end
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
error("scale must be mod-2")
end

View file

@ -6,10 +6,22 @@ function nn.SpatialConvolutionMM:reset(stdv)
self.bias:fill(0)
end
local srcnn = {}
function srcnn.waifu2x()
function srcnn.waifu2x(color)
local model = nn.Sequential()
local ch = nil
if color == "rgb" then
ch = 3
elseif color == "y" then
ch = 1
else
if color then
error("unknown color: " .. color)
else
error("unknown color: nil")
end
end
model:add(nn.SpatialConvolutionMM(1, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
@ -21,7 +33,7 @@ function srcnn.waifu2x()
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, 1, 3, 3, 1, 1, 0, 0))
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
--model:cuda()
--print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
@ -30,10 +42,19 @@ function srcnn.waifu2x()
end
-- current 4x is worse then 2x * 2
function srcnn.waifu4x()
function srcnn.waifu4x(color)
local model = nn.Sequential()
local ch = nil
if color == "rgb" then
ch = 3
elseif color == "y" then
ch = 1
else
error("unknown color: " .. color)
end
model:add(nn.SpatialConvolutionMM(1, 32, 9, 9, 1, 1, 0, 0))
model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
@ -45,7 +66,7 @@ function srcnn.waifu4x()
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, 1, 5, 5, 1, 1, 0, 0))
model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
return model, 13

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -59,7 +59,7 @@ local function validate(model, criterion, data)
end
local function train()
local model, offset = settings.create_model()
local model, offset = settings.create_model(settings.color)
assert(offset == settings.block_offset)
local criterion = nn.MSECriterion():cuda()
local x = torch.load(settings.images)
@ -72,6 +72,12 @@ local function train()
learningRate = settings.learning_rate,
xBatchSize = settings.batch_size,
}
local ch = nil
if settings.color == "y" then
ch = 1
elseif settings.color == "rgb" then
ch = 3
end
local transformer = function(x, is_validation)
if is_validation == nil then is_validation = false end
if settings.method == "scale" then
@ -79,20 +85,25 @@ local function train()
settings.scale,
settings.crop_size, offset,
{ color_augment = not is_validation,
random_half = settings.random_half})
random_half = settings.random_half,
rgb = (settings.color == "rgb")
})
elseif settings.method == "noise" then
return pairwise_transform.jpeg(x,
settings.noise_level,
settings.crop_size, offset,
{ color_augment = not is_validation,
random_half = settings.random_half})
random_half = settings.random_half,
rgb = (settings.color == "rgb")
})
elseif settings.method == "noise_scale" then
return pairwise_transform.jpeg_scale(x,
settings.scale,
settings.noise_level,
settings.crop_size, offset,
{ color_augment = not is_validation,
random_half = settings.random_half
random_half = settings.random_half,
rgb = (settings.color == "rgb")
})
end
end
@ -109,8 +120,8 @@ local function train()
print("# " .. epoch)
print(minibatch_adam(model, criterion, train_x, adam_config,
transformer,
{1, settings.crop_size, settings.crop_size},
{1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
{ch, settings.crop_size, settings.crop_size},
{ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
))
model:evaluate()
print("# validation")

View file

@ -1,10 +1,10 @@
#!/bin/sh
th train.lua -method noise -noise_level 1 -model_dir models/anime_style_art -test images/miku_noisy.png
th cleanup_model.lua -model models/anime_style_art/noise1_model.t7 -oformat ascii
th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
th train.lua -method noise -noise_level 2 -model_dir models/anime_style_art -test images/miku_noisy.png
th cleanup_model.lua -model models/anime_style_art/noise2_model.t7 -oformat ascii
th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
th train.lua -method scale -scale 2 -model_dir models/anime_style_art -test images/miku_small.png
th cleanup_model.lua -model models/anime_style_art/scale2.0x_model.t7 -oformat ascii
th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii

View file

@ -105,7 +105,7 @@ local function waifu2x()
cmd:option("-l", "", 'path of the image-list')
cmd:option("-scale", 2, 'scale factor')
cmd:option("-o", "(auto)", 'path of the output file')
cmd:option("-model_dir", "./models/anime_style_art", 'model directory')
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-crop_size", 128, 'patch size per process')

View file

@ -15,7 +15,7 @@ local iproc = require './lib/iproc'
local reconstruct = require './lib/reconstruct'
local image_loader = require './lib/image_loader'
local MODEL_DIR = "./models/anime_style_art"
local MODEL_DIR = "./models/anime_style_art_rgb"
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")