commit 6a91d7d3b4
Author: nagadomi
Date:   2015-06-22 18:17:50 +00:00

27 changed files with 194 additions and 50 deletions

View file

@@ -5,7 +5,7 @@ require './lib/LeakyReLU'
 local srcnn = require 'lib/srcnn'
 local function cudnn2cunn(cudnn_model)
-   local cunn_model = srcnn.waifu2x()
+   local cunn_model = srcnn.waifu2x("y")
    local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
    local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")

10 binary files changed (image previews omitted).

View file

@@ -73,12 +73,18 @@ function image_loader.decode_byte(blob)
 end
 function image_loader.load_float(file)
    local fp = io.open(file, "rb")
+   if not fp then
+      error(file .. ": failed to load image")
+   end
    local buff = fp:read("*a")
    fp:close()
    return image_loader.decode_float(buff)
 end
 function image_loader.load_byte(file)
    local fp = io.open(file, "rb")
+   if not fp then
+      error(file .. ": failed to load image")
+   end
    local buff = fp:read("*a")
    fp:close()
    return image_loader.decode_byte(buff)
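Note on the image_loader change above: a path that cannot be opened now raises a Lua error naming the file, instead of failing later on a nil file handle. A minimal caller-side sketch (not part of this commit; the file name below is hypothetical) showing how a batch loop could skip unreadable files:

    local image_loader = require './lib/image_loader'

    -- pcall converts the new error into a status flag plus message,
    -- so a loop over many files can log the bad one and continue.
    local ok, img = pcall(image_loader.load_float, "maybe_missing.png")
    if not ok then
       print("skipped: " .. tostring(img))  -- on failure, img holds the error message
    end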

View file

@@ -2,15 +2,6 @@ local gm = require 'graphicsmagick'
 local image = require 'image'
 local iproc = {}
-function iproc.sample(src, width, height)
-   local t = "float"
-   if src:type() == "torch.ByteTensor" then
-      t = "byte"
-   end
-   local im = gm.Image(src, "RGB", "DHW")
-   im:sample(math.ceil(width), math.ceil(height))
-   return im:toTensor(t, "RGB", "DHW")
-end
 function iproc.scale(src, width, height, filter)
    local t = "float"
    if src:type() == "torch.ByteTensor" then

View file

@@ -52,7 +52,7 @@ local function flip_augment(x, y)
 end
 local INTERPOLATION_PADDING = 16
 function pairwise_transform.scale(src, scale, size, offset, options)
-   options = options or {color_augment = true, random_half = true}
+   options = options or {color_augment = true, random_half = true, rgb = true}
    if options.random_half then
       src = random_half(src)
    end
@@ -81,8 +81,12 @@ function pairwise_transform.scale(src, scale, size, offset, options)
    x = iproc.scale(x, y:size(3), y:size(2))
    y = y:float():div(255)
    x = x:float():div(255)
-   y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-   x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   if options.rgb then
+   else
+      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
+      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   end
    y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
    x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
@@ -90,7 +94,7 @@ function pairwise_transform.scale(src, scale, size, offset, options)
    return x, y
 end
 function pairwise_transform.jpeg_(src, quality, size, offset, options)
-   options = options or {color_augment = true, random_half = true}
+   options = options or {color_augment = true, random_half = true, rgb = true}
    if options.random_half then
       src = random_half(src)
    end
@@ -106,6 +110,7 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
    for i = 1, #quality do
      x = gm.Image(x, "RGB", "DHW")
      x:format("jpeg")
+     x:samplingFactors({1.0, 1.0, 1.0})
      local blob, len = x:toBlob(quality[i])
      x:fromBlob(blob, len)
      x = x:toTensor("byte", "RGB", "DHW")
@@ -117,9 +122,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
    x = x:float():div(255)
    x, y = flip_augment(x, y)
-   y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-   x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   if options.rgb then
+   else
+      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
+      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   end
    return x, image.crop(y, offset, offset, size - offset, size - offset)
 end
 function pairwise_transform.jpeg(src, level, size, offset, options)
@@ -159,12 +167,12 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    local down_scale = 1.0 / scale
    local filters = {
      "Box", -- 0.012756949974688
-     --"Blackman", -- 0.013191924552285
+     "Blackman", -- 0.013191924552285
      --"Cartom", -- 0.013753536746706
      --"Hanning", -- 0.013761314529647
      --"Hermite", -- 0.013850225205266
-     --"SincFast", -- 0.014095824314306
-     --"Jinc", -- 0.014244299255442
+     "SincFast", -- 0.014095824314306
+     "Jinc", -- 0.014244299255442
    }
    local downscale_filter = filters[torch.random(1, #filters)]
    local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
@@ -180,6 +188,7 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    for i = 1, #quality do
      x = gm.Image(x, "RGB", "DHW")
      x:format("jpeg")
+     x:samplingFactors({1.0, 1.0, 1.0})
      local blob, len = x:toBlob(quality[i])
      x:fromBlob(blob, len)
      x = x:toTensor("byte", "RGB", "DHW")
@@ -194,10 +203,13 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
    x = x:float():div(255)
    y = y:float():div(255)
    x, y = flip_augment(x, y)
-   y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
-   x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   if options.rgb then
+   else
+      y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
+      x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
+   end
    return x, image.crop(y, offset, offset, size - offset, size - offset)
 end
 function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options)
@@ -247,7 +259,7 @@ local function test_scale()
    local loader = require './image_loader'
    local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
    for i = 1, 9 do
-      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true})
+      local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true, rgb = true})
      image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
      image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
      print(y:size(), x:size())
@@ -272,8 +284,8 @@ local function test_jpeg_scale()
      --print(x:mean(), y:mean())
    end
 end
---test_jpeg()
 --test_scale()
+--test_jpeg()
 --test_jpeg_scale()
 return pairwise_transform
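The new rgb option in pairwise_transform decides whether a generated training pair keeps all three RGB channels or is reduced to the Y (luminance) plane as before; when rgb is true the rgb2yuv step is simply skipped. A small sketch of the difference (not part of this commit; require paths and the sample image are assumed to match the repository layout):

    local pairwise_transform = require './lib/pairwise_transform'
    local image_loader = require './lib/image_loader'

    local src = image_loader.load_byte("images/miku_CC_BY-NC.jpg")  -- any training image

    -- RGB pair: 3-channel patches, no YUV conversion
    local x_rgb, y_rgb = pairwise_transform.scale(src, 2.0, 128, 7,
       {color_augment = true, random_half = true, rgb = true})

    -- Y pair: patches reduced to the luminance plane (the previous behavior)
    local x_y, y_y = pairwise_transform.scale(src, 2.0, 128, 7,
       {color_augment = true, random_half = true, rgb = false})

    print(x_rgb:size(1), x_y:size(1))  -- expected: 3 and 1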

View file

@@ -1,7 +1,7 @@
 require 'image'
 local iproc = require './iproc'
-local function reconstruct_layer(model, x, offset, block_size)
+local function reconstruct_y(model, x, offset, block_size)
    if x:dim() == 2 then
       x = x:reshape(1, x:size(1), x:size(2))
    end
@@ -26,8 +26,40 @@ local function reconstruct_layer(model, x, offset, block_size)
    end
    return new_x
 end
+local function reconstruct_rgb(model, x, offset, block_size)
+   local new_x = torch.Tensor():resizeAs(x):zero()
+   local output_size = block_size - offset * 2
+   local input = torch.CudaTensor(1, 3, block_size, block_size)
+   for i = 1, x:size(2), output_size do
+      for j = 1, x:size(3), output_size do
+         if i + block_size - 1 <= x:size(2) and j + block_size - 1 <= x:size(3) then
+            local index = {{},
+                           {i, i + block_size - 1},
+                           {j, j + block_size - 1}}
+            input:copy(x[index])
+            local output = model:forward(input):float():view(3, output_size, output_size)
+            local output_index = {{},
+                                  {i + offset, offset + i + output_size - 1},
+                                  {offset + j, offset + j + output_size - 1}}
+            new_x[output_index]:copy(output)
+         end
+      end
+   end
+   return new_x
+end
+function model_is_rgb(model)
+   if model:get(model:size() - 1).weight:size(1) == 3 then
+      -- 3ch RGB
+      return true
+   else
+      -- 1ch Y
+      return false
+   end
+end
 local reconstruct = {}
-function reconstruct.image(model, x, offset, block_size)
+function reconstruct.image_y(model, x, offset, block_size)
    block_size = block_size or 128
    local output_size = block_size - offset * 2
    local h_blocks = math.floor(x:size(2) / output_size) +
@@ -42,7 +74,7 @@ function reconstruct.image(model, x, offset, block_size)
    local pad_h2 = (h - offset) - x:size(2)
    local pad_w2 = (w - offset) - x:size(3)
    local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_layer(model, yuv[1], offset, block_size)
+   local y = reconstruct_y(model, yuv[1], offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    yuv[1]:copy(y)
@@ -55,7 +87,7 @@ function reconstruct.image(model, x, offset, block_size)
    return output
 end
-function reconstruct.scale(model, scale, x, offset, block_size)
+function reconstruct.scale_y(model, scale, x, offset, block_size)
    block_size = block_size or 128
    local x_jinc = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Jinc")
    x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
@@ -74,7 +106,7 @@ function reconstruct.scale(model, scale, x, offset, block_size)
    local pad_w2 = (w - offset) - x:size(3)
    local yuv_nn = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
    local yuv_jinc = image.rgb2yuv(iproc.padding(x_jinc, pad_w1, pad_w2, pad_h1, pad_h2))
-   local y = reconstruct_layer(model, yuv_nn[1], offset, block_size)
+   local y = reconstruct_y(model, yuv_nn[1], offset, block_size)
    y[torch.lt(y, 0)] = 0
    y[torch.gt(y, 1)] = 1
    yuv_jinc[1]:copy(y)
@@ -87,5 +119,72 @@ function reconstruct.scale(model, scale, x, offset, block_size)
    return output
 end
+function reconstruct.image_rgb(model, x, offset, block_size)
+   block_size = block_size or 128
+   local output_size = block_size - offset * 2
+   local h_blocks = math.floor(x:size(2) / output_size) +
+      ((x:size(2) % output_size == 0 and 0) or 1)
+   local w_blocks = math.floor(x:size(3) / output_size) +
+      ((x:size(3) % output_size == 0 and 0) or 1)
+   local h = offset + h_blocks * output_size + offset
+   local w = offset + w_blocks * output_size + offset
+   local pad_h1 = offset
+   local pad_w1 = offset
+   local pad_h2 = (h - offset) - x:size(2)
+   local pad_w2 = (w - offset) - x:size(3)
+   local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
+   local y = reconstruct_rgb(model, input, offset, block_size)
+   local output = image.crop(y,
+                             pad_w1, pad_h1,
+                             y:size(3) - pad_w2, y:size(2) - pad_h2)
+   collectgarbage()
+   output[torch.lt(output, 0)] = 0
+   output[torch.gt(output, 1)] = 1
+   return output
+end
+function reconstruct.scale_rgb(model, scale, x, offset, block_size)
+   block_size = block_size or 128
+   x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
+   local output_size = block_size - offset * 2
+   local h_blocks = math.floor(x:size(2) / output_size) +
+      ((x:size(2) % output_size == 0 and 0) or 1)
+   local w_blocks = math.floor(x:size(3) / output_size) +
+      ((x:size(3) % output_size == 0 and 0) or 1)
+   local h = offset + h_blocks * output_size + offset
+   local w = offset + w_blocks * output_size + offset
+   local pad_h1 = offset
+   local pad_w1 = offset
+   local pad_h2 = (h - offset) - x:size(2)
+   local pad_w2 = (w - offset) - x:size(3)
+   local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
+   local y = reconstruct_rgb(model, input, offset, block_size)
+   local output = image.crop(y,
+                             pad_w1, pad_h1,
+                             y:size(3) - pad_w2, y:size(2) - pad_h2)
+   output[torch.lt(output, 0)] = 0
+   output[torch.gt(output, 1)] = 1
+   collectgarbage()
+   return output
+end
+function reconstruct.image(model, x, offset, block_size)
+   if model_is_rgb(model) then
+      return reconstruct.image_rgb(model, x, offset, block_size)
+   else
+      return reconstruct.image_y(model, x, offset, block_size)
+   end
+end
+function reconstruct.scale(model, scale, x, offset, block_size)
+   if model_is_rgb(model) then
+      return reconstruct.scale_rgb(model, scale, x, offset, block_size)
+   else
+      return reconstruct.scale_y(model, scale, x, offset, block_size)
+   end
+end
 return reconstruct
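reconstruct.image and reconstruct.scale are now front-ends that dispatch automatically: model_is_rgb looks at the output-channel count of the model's last convolution and routes to the RGB or Y code path. A usage sketch (not part of this commit; the model and image paths are hypothetical, and a CUDA-enabled Torch setup is assumed because the block-wise forward pass uses CudaTensor):

    require 'image'
    require 'cunn'
    local reconstruct = require './lib/reconstruct'
    local image_loader = require './lib/image_loader'

    -- Whether the network is 1ch (Y) or 3ch (RGB) is detected at call time.
    local model = torch.load("models/anime_style_art_rgb/scale2.0x_model.t7", "ascii")
    local x = image_loader.load_float("input.png")   -- hypothetical input path

    -- 7 is assumed here to be the model's block offset (srcnn.waifu2x's second return value)
    local y = reconstruct.scale(model, 2.0, x, 7)
    image.save("output_2x.png", y)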

View file

@@ -22,6 +22,7 @@ cmd:option("-test", "images/miku_small.png", 'test image file')
 cmd:option("-model_dir", "./models", 'model directory')
 cmd:option("-method", "scale", '(noise|scale|noise_scale)')
 cmd:option("-noise_level", 1, '(1|2)')
+cmd:option("-color", 'rgb', '(y|rgb)')
 cmd:option("-scale", 2.0, 'scale')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
 cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
@@ -46,6 +47,9 @@ elseif settings.method == "noise_scale" then
 else
    error("unknown method: " .. settings.method)
 end
+if not (settings.color == "rgb" or settings.color == "y") then
+   error("color must be y or rgb")
+end
 if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
    error("scale must be mod-2")
 end

View file

@@ -6,10 +6,22 @@ function nn.SpatialConvolutionMM:reset(stdv)
    self.bias:fill(0)
 end
 local srcnn = {}
-function srcnn.waifu2x()
+function srcnn.waifu2x(color)
    local model = nn.Sequential()
+   local ch = nil
+   if color == "rgb" then
+      ch = 3
+   elseif color == "y" then
+      ch = 1
+   else
+      if color then
+         error("unknown color: " .. color)
+      else
+         error("unknown color: nil")
+      end
+   end
-   model:add(nn.SpatialConvolutionMM(1, 32, 3, 3, 1, 1, 0, 0))
+   model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
@@ -21,7 +33,7 @@ function srcnn.waifu2x()
    model:add(nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, 1, 3, 3, 1, 1, 0, 0))
+   model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    --model:cuda()
    --print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
@@ -30,10 +42,19 @@ function srcnn.waifu2x()
 end
 -- current 4x is worse then 2x * 2
-function srcnn.waifu4x()
+function srcnn.waifu4x(color)
    local model = nn.Sequential()
+   local ch = nil
+   if color == "rgb" then
+      ch = 3
+   elseif color == "y" then
+      ch = 1
+   else
+      error("unknown color: " .. color)
+   end
-   model:add(nn.SpatialConvolutionMM(1, 32, 9, 9, 1, 1, 0, 0))
+   model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
@@ -45,7 +66,7 @@ function srcnn.waifu4x()
    model:add(nn.LeakyReLU(0.1))
    model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
    model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, 1, 5, 5, 1, 1, 0, 0))
+   model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
    model:add(nn.View(-1):setNumInputDims(3))
    return model, 13
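Both constructors now take an explicit color mode, so the channel count is fixed at model-creation time. A minimal sketch (not part of this commit) of the two variants, with requires mirroring the cudnn2cunn script above:

    require 'nn'
    require './lib/LeakyReLU'          -- provides nn.LeakyReLU used by the model
    local srcnn = require 'lib/srcnn'

    local y_model, y_offset = srcnn.waifu2x("y")       -- 1 input/output channel (luminance)
    local rgb_model, rgb_offset = srcnn.waifu2x("rgb") -- 3 input/output channels (full color)

    -- nil or an unknown string now raises "unknown color: ...".
    -- The first SpatialConvolutionMM weight reflects the choice:
    print(y_model:get(1).weight:size())    -- 32 x 9   (1 * 3 * 3 inputs per filter)
    print(rgb_model:get(1).weight:size())  -- 32 x 27  (3 * 3 * 3 inputs per filter)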

6 binary files changed (not shown).

View file

@@ -59,7 +59,7 @@ local function validate(model, criterion, data)
 end
 local function train()
-   local model, offset = settings.create_model()
+   local model, offset = settings.create_model(settings.color)
    assert(offset == settings.block_offset)
    local criterion = nn.MSECriterion():cuda()
    local x = torch.load(settings.images)
@@ -72,6 +72,12 @@ local function train()
      learningRate = settings.learning_rate,
      xBatchSize = settings.batch_size,
    }
+   local ch = nil
+   if settings.color == "y" then
+      ch = 1
+   elseif settings.color == "rgb" then
+      ch = 3
+   end
    local transformer = function(x, is_validation)
       if is_validation == nil then is_validation = false end
       if settings.method == "scale" then
@@ -79,20 +85,25 @@ local function train()
            settings.scale,
            settings.crop_size, offset,
            { color_augment = not is_validation,
-             random_half = settings.random_half})
+             random_half = settings.random_half,
+             rgb = (settings.color == "rgb")
+           })
       elseif settings.method == "noise" then
          return pairwise_transform.jpeg(x,
            settings.noise_level,
            settings.crop_size, offset,
            { color_augment = not is_validation,
-             random_half = settings.random_half})
+             random_half = settings.random_half,
+             rgb = (settings.color == "rgb")
+           })
       elseif settings.method == "noise_scale" then
         return pairwise_transform.jpeg_scale(x,
           settings.scale,
          settings.noise_level,
          settings.crop_size, offset,
          { color_augment = not is_validation,
-           random_half = settings.random_half
+           random_half = settings.random_half,
+           rgb = (settings.color == "rgb")
          })
       end
    end
@@ -109,8 +120,8 @@ local function train()
      print("# " .. epoch)
      print(minibatch_adam(model, criterion, train_x, adam_config,
                           transformer,
-                          {1, settings.crop_size, settings.crop_size},
-                          {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
+                          {ch, settings.crop_size, settings.crop_size},
+                          {ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
                          ))
      model:evaluate()
      print("# validation")
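With -color rgb the minibatches handed to minibatch_adam become 3-channel, so only the channel dimension of the input and target shapes changes. For illustration (values assumed, not taken from this commit: default crop_size 128 and block offset 7):

    local ch = 3                          -- settings.color == "rgb"; 1 for "y"
    local crop_size, offset = 128, 7      -- assumed defaults
    local input_shape  = {ch, crop_size, crop_size}                            -- {3, 128, 128}
    local target_shape = {ch, crop_size - offset * 2, crop_size - offset * 2}  -- {3, 114, 114}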

View file

@@ -1,10 +1,10 @@
 #!/bin/sh
-th train.lua -method noise -noise_level 1 -model_dir models/anime_style_art -test images/miku_noisy.png
-th cleanup_model.lua -model models/anime_style_art/noise1_model.t7 -oformat ascii
-th train.lua -method noise -noise_level 2 -model_dir models/anime_style_art -test images/miku_noisy.png
-th cleanup_model.lua -model models/anime_style_art/noise2_model.t7 -oformat ascii
-th train.lua -method scale -scale 2 -model_dir models/anime_style_art -test images/miku_small.png
-th cleanup_model.lua -model models/anime_style_art/scale2.0x_model.t7 -oformat ascii
+th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
+th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
+th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
+th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
+th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
+th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii

View file

@@ -105,7 +105,7 @@ local function waifu2x()
    cmd:option("-l", "", 'path of the image-list')
    cmd:option("-scale", 2, 'scale factor')
    cmd:option("-o", "(auto)", 'path of the output file')
-   cmd:option("-model_dir", "./models/anime_style_art", 'model directory')
+   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
    cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
    cmd:option("-noise_level", 1, '(1|2)')
    cmd:option("-crop_size", 128, 'patch size per process')

View file

@@ -23,7 +23,7 @@ local iproc = require './lib/iproc'
 local reconstruct = require './lib/reconstruct'
 local image_loader = require './lib/image_loader'
-local MODEL_DIR = "./models/anime_style_art"
+local MODEL_DIR = "./models/anime_style_art_rgb"
 local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
 local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")