Merge branch 'master' of https://github.com/nagadomi/waifu2x
|
@ -5,7 +5,7 @@ require './lib/LeakyReLU'
|
||||||
local srcnn = require 'lib/srcnn'
|
local srcnn = require 'lib/srcnn'
|
||||||
|
|
||||||
local function cudnn2cunn(cudnn_model)
|
local function cudnn2cunn(cudnn_model)
|
||||||
local cunn_model = srcnn.waifu2x()
|
local cunn_model = srcnn.waifu2x("y")
|
||||||
local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
|
local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
|
||||||
local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")
|
local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")
|
||||||
|
|
||||||
|
|
Before Width: | Height: | Size: 383 KiB After Width: | Height: | Size: 315 KiB |
Before Width: | Height: | Size: 648 KiB After Width: | Height: | Size: 605 KiB |
Before Width: | Height: | Size: 150 KiB After Width: | Height: | Size: 154 KiB |
Before Width: | Height: | Size: 148 KiB After Width: | Height: | Size: 138 KiB |
Before Width: | Height: | Size: 150 KiB After Width: | Height: | Size: 136 KiB |
BIN
images/slide.odp
BIN
images/slide.png
Before Width: | Height: | Size: 1.2 MiB After Width: | Height: | Size: 1.2 MiB |
Before Width: | Height: | Size: 493 KiB After Width: | Height: | Size: 499 KiB |
Before Width: | Height: | Size: 372 KiB After Width: | Height: | Size: 380 KiB |
Before Width: | Height: | Size: 377 KiB After Width: | Height: | Size: 378 KiB |
|
@ -73,12 +73,18 @@ function image_loader.decode_byte(blob)
|
||||||
end
|
end
|
||||||
function image_loader.load_float(file)
|
function image_loader.load_float(file)
|
||||||
local fp = io.open(file, "rb")
|
local fp = io.open(file, "rb")
|
||||||
|
if not fp then
|
||||||
|
error(file .. ": failed to load image")
|
||||||
|
end
|
||||||
local buff = fp:read("*a")
|
local buff = fp:read("*a")
|
||||||
fp:close()
|
fp:close()
|
||||||
return image_loader.decode_float(buff)
|
return image_loader.decode_float(buff)
|
||||||
end
|
end
|
||||||
function image_loader.load_byte(file)
|
function image_loader.load_byte(file)
|
||||||
local fp = io.open(file, "rb")
|
local fp = io.open(file, "rb")
|
||||||
|
if not fp then
|
||||||
|
error(file .. ": failed to load image")
|
||||||
|
end
|
||||||
local buff = fp:read("*a")
|
local buff = fp:read("*a")
|
||||||
fp:close()
|
fp:close()
|
||||||
return image_loader.decode_byte(buff)
|
return image_loader.decode_byte(buff)
|
||||||
|
|
|
@ -2,15 +2,6 @@ local gm = require 'graphicsmagick'
|
||||||
local image = require 'image'
|
local image = require 'image'
|
||||||
local iproc = {}
|
local iproc = {}
|
||||||
|
|
||||||
function iproc.sample(src, width, height)
|
|
||||||
local t = "float"
|
|
||||||
if src:type() == "torch.ByteTensor" then
|
|
||||||
t = "byte"
|
|
||||||
end
|
|
||||||
local im = gm.Image(src, "RGB", "DHW")
|
|
||||||
im:sample(math.ceil(width), math.ceil(height))
|
|
||||||
return im:toTensor(t, "RGB", "DHW")
|
|
||||||
end
|
|
||||||
function iproc.scale(src, width, height, filter)
|
function iproc.scale(src, width, height, filter)
|
||||||
local t = "float"
|
local t = "float"
|
||||||
if src:type() == "torch.ByteTensor" then
|
if src:type() == "torch.ByteTensor" then
|
||||||
|
|
|
@ -52,7 +52,7 @@ local function flip_augment(x, y)
|
||||||
end
|
end
|
||||||
local INTERPOLATION_PADDING = 16
|
local INTERPOLATION_PADDING = 16
|
||||||
function pairwise_transform.scale(src, scale, size, offset, options)
|
function pairwise_transform.scale(src, scale, size, offset, options)
|
||||||
options = options or {color_augment = true, random_half = true}
|
options = options or {color_augment = true, random_half = true, rgb = true}
|
||||||
if options.random_half then
|
if options.random_half then
|
||||||
src = random_half(src)
|
src = random_half(src)
|
||||||
end
|
end
|
||||||
|
@ -81,8 +81,12 @@ function pairwise_transform.scale(src, scale, size, offset, options)
|
||||||
x = iproc.scale(x, y:size(3), y:size(2))
|
x = iproc.scale(x, y:size(3), y:size(2))
|
||||||
y = y:float():div(255)
|
y = y:float():div(255)
|
||||||
x = x:float():div(255)
|
x = x:float():div(255)
|
||||||
|
|
||||||
|
if options.rgb then
|
||||||
|
else
|
||||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
||||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
||||||
|
end
|
||||||
|
|
||||||
y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
|
y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
|
||||||
x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
|
x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
|
||||||
|
@ -90,7 +94,7 @@ function pairwise_transform.scale(src, scale, size, offset, options)
|
||||||
return x, y
|
return x, y
|
||||||
end
|
end
|
||||||
function pairwise_transform.jpeg_(src, quality, size, offset, options)
|
function pairwise_transform.jpeg_(src, quality, size, offset, options)
|
||||||
options = options or {color_augment = true, random_half = true}
|
options = options or {color_augment = true, random_half = true, rgb = true}
|
||||||
if options.random_half then
|
if options.random_half then
|
||||||
src = random_half(src)
|
src = random_half(src)
|
||||||
end
|
end
|
||||||
|
@ -106,6 +110,7 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
|
||||||
for i = 1, #quality do
|
for i = 1, #quality do
|
||||||
x = gm.Image(x, "RGB", "DHW")
|
x = gm.Image(x, "RGB", "DHW")
|
||||||
x:format("jpeg")
|
x:format("jpeg")
|
||||||
|
x:samplingFactors({1.0, 1.0, 1.0})
|
||||||
local blob, len = x:toBlob(quality[i])
|
local blob, len = x:toBlob(quality[i])
|
||||||
x:fromBlob(blob, len)
|
x:fromBlob(blob, len)
|
||||||
x = x:toTensor("byte", "RGB", "DHW")
|
x = x:toTensor("byte", "RGB", "DHW")
|
||||||
|
@ -117,8 +122,11 @@ function pairwise_transform.jpeg_(src, quality, size, offset, options)
|
||||||
x = x:float():div(255)
|
x = x:float():div(255)
|
||||||
x, y = flip_augment(x, y)
|
x, y = flip_augment(x, y)
|
||||||
|
|
||||||
|
if options.rgb then
|
||||||
|
else
|
||||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
||||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
||||||
|
end
|
||||||
|
|
||||||
return x, image.crop(y, offset, offset, size - offset, size - offset)
|
return x, image.crop(y, offset, offset, size - offset, size - offset)
|
||||||
end
|
end
|
||||||
|
@ -159,12 +167,12 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
|
||||||
local down_scale = 1.0 / scale
|
local down_scale = 1.0 / scale
|
||||||
local filters = {
|
local filters = {
|
||||||
"Box", -- 0.012756949974688
|
"Box", -- 0.012756949974688
|
||||||
--"Blackman", -- 0.013191924552285
|
"Blackman", -- 0.013191924552285
|
||||||
--"Cartom", -- 0.013753536746706
|
--"Cartom", -- 0.013753536746706
|
||||||
--"Hanning", -- 0.013761314529647
|
--"Hanning", -- 0.013761314529647
|
||||||
--"Hermite", -- 0.013850225205266
|
--"Hermite", -- 0.013850225205266
|
||||||
--"SincFast", -- 0.014095824314306
|
"SincFast", -- 0.014095824314306
|
||||||
--"Jinc", -- 0.014244299255442
|
"Jinc", -- 0.014244299255442
|
||||||
}
|
}
|
||||||
local downscale_filter = filters[torch.random(1, #filters)]
|
local downscale_filter = filters[torch.random(1, #filters)]
|
||||||
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
|
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
|
||||||
|
@ -180,6 +188,7 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
|
||||||
for i = 1, #quality do
|
for i = 1, #quality do
|
||||||
x = gm.Image(x, "RGB", "DHW")
|
x = gm.Image(x, "RGB", "DHW")
|
||||||
x:format("jpeg")
|
x:format("jpeg")
|
||||||
|
x:samplingFactors({1.0, 1.0, 1.0})
|
||||||
local blob, len = x:toBlob(quality[i])
|
local blob, len = x:toBlob(quality[i])
|
||||||
x:fromBlob(blob, len)
|
x:fromBlob(blob, len)
|
||||||
x = x:toTensor("byte", "RGB", "DHW")
|
x = x:toTensor("byte", "RGB", "DHW")
|
||||||
|
@ -195,8 +204,11 @@ function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, optio
|
||||||
y = y:float():div(255)
|
y = y:float():div(255)
|
||||||
x, y = flip_augment(x, y)
|
x, y = flip_augment(x, y)
|
||||||
|
|
||||||
|
if options.rgb then
|
||||||
|
else
|
||||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
||||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
||||||
|
end
|
||||||
|
|
||||||
return x, image.crop(y, offset, offset, size - offset, size - offset)
|
return x, image.crop(y, offset, offset, size - offset, size - offset)
|
||||||
end
|
end
|
||||||
|
@ -247,7 +259,7 @@ local function test_scale()
|
||||||
local loader = require './image_loader'
|
local loader = require './image_loader'
|
||||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||||
for i = 1, 9 do
|
for i = 1, 9 do
|
||||||
local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true})
|
local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true, rgb = true})
|
||||||
image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
|
image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
|
||||||
image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
|
image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
|
||||||
print(y:size(), x:size())
|
print(y:size(), x:size())
|
||||||
|
@ -272,8 +284,8 @@ local function test_jpeg_scale()
|
||||||
--print(x:mean(), y:mean())
|
--print(x:mean(), y:mean())
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
--test_jpeg()
|
|
||||||
--test_scale()
|
--test_scale()
|
||||||
|
--test_jpeg()
|
||||||
--test_jpeg_scale()
|
--test_jpeg_scale()
|
||||||
|
|
||||||
return pairwise_transform
|
return pairwise_transform
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
require 'image'
|
require 'image'
|
||||||
local iproc = require './iproc'
|
local iproc = require './iproc'
|
||||||
|
|
||||||
local function reconstruct_layer(model, x, offset, block_size)
|
local function reconstruct_y(model, x, offset, block_size)
|
||||||
if x:dim() == 2 then
|
if x:dim() == 2 then
|
||||||
x = x:reshape(1, x:size(1), x:size(2))
|
x = x:reshape(1, x:size(1), x:size(2))
|
||||||
end
|
end
|
||||||
|
@ -26,8 +26,40 @@ local function reconstruct_layer(model, x, offset, block_size)
|
||||||
end
|
end
|
||||||
return new_x
|
return new_x
|
||||||
end
|
end
|
||||||
|
local function reconstruct_rgb(model, x, offset, block_size)
|
||||||
|
local new_x = torch.Tensor():resizeAs(x):zero()
|
||||||
|
local output_size = block_size - offset * 2
|
||||||
|
local input = torch.CudaTensor(1, 3, block_size, block_size)
|
||||||
|
|
||||||
|
for i = 1, x:size(2), output_size do
|
||||||
|
for j = 1, x:size(3), output_size do
|
||||||
|
if i + block_size - 1 <= x:size(2) and j + block_size - 1 <= x:size(3) then
|
||||||
|
local index = {{},
|
||||||
|
{i, i + block_size - 1},
|
||||||
|
{j, j + block_size - 1}}
|
||||||
|
input:copy(x[index])
|
||||||
|
local output = model:forward(input):float():view(3, output_size, output_size)
|
||||||
|
local output_index = {{},
|
||||||
|
{i + offset, offset + i + output_size - 1},
|
||||||
|
{offset + j, offset + j + output_size - 1}}
|
||||||
|
new_x[output_index]:copy(output)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return new_x
|
||||||
|
end
|
||||||
|
function model_is_rgb(model)
|
||||||
|
if model:get(model:size() - 1).weight:size(1) == 3 then
|
||||||
|
-- 3ch RGB
|
||||||
|
return true
|
||||||
|
else
|
||||||
|
-- 1ch Y
|
||||||
|
return false
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
local reconstruct = {}
|
local reconstruct = {}
|
||||||
function reconstruct.image(model, x, offset, block_size)
|
function reconstruct.image_y(model, x, offset, block_size)
|
||||||
block_size = block_size or 128
|
block_size = block_size or 128
|
||||||
local output_size = block_size - offset * 2
|
local output_size = block_size - offset * 2
|
||||||
local h_blocks = math.floor(x:size(2) / output_size) +
|
local h_blocks = math.floor(x:size(2) / output_size) +
|
||||||
|
@ -42,7 +74,7 @@ function reconstruct.image(model, x, offset, block_size)
|
||||||
local pad_h2 = (h - offset) - x:size(2)
|
local pad_h2 = (h - offset) - x:size(2)
|
||||||
local pad_w2 = (w - offset) - x:size(3)
|
local pad_w2 = (w - offset) - x:size(3)
|
||||||
local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
|
local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
|
||||||
local y = reconstruct_layer(model, yuv[1], offset, block_size)
|
local y = reconstruct_y(model, yuv[1], offset, block_size)
|
||||||
y[torch.lt(y, 0)] = 0
|
y[torch.lt(y, 0)] = 0
|
||||||
y[torch.gt(y, 1)] = 1
|
y[torch.gt(y, 1)] = 1
|
||||||
yuv[1]:copy(y)
|
yuv[1]:copy(y)
|
||||||
|
@ -55,7 +87,7 @@ function reconstruct.image(model, x, offset, block_size)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
end
|
end
|
||||||
function reconstruct.scale(model, scale, x, offset, block_size)
|
function reconstruct.scale_y(model, scale, x, offset, block_size)
|
||||||
block_size = block_size or 128
|
block_size = block_size or 128
|
||||||
local x_jinc = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Jinc")
|
local x_jinc = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Jinc")
|
||||||
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
|
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
|
||||||
|
@ -74,7 +106,7 @@ function reconstruct.scale(model, scale, x, offset, block_size)
|
||||||
local pad_w2 = (w - offset) - x:size(3)
|
local pad_w2 = (w - offset) - x:size(3)
|
||||||
local yuv_nn = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
|
local yuv_nn = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2))
|
||||||
local yuv_jinc = image.rgb2yuv(iproc.padding(x_jinc, pad_w1, pad_w2, pad_h1, pad_h2))
|
local yuv_jinc = image.rgb2yuv(iproc.padding(x_jinc, pad_w1, pad_w2, pad_h1, pad_h2))
|
||||||
local y = reconstruct_layer(model, yuv_nn[1], offset, block_size)
|
local y = reconstruct_y(model, yuv_nn[1], offset, block_size)
|
||||||
y[torch.lt(y, 0)] = 0
|
y[torch.lt(y, 0)] = 0
|
||||||
y[torch.gt(y, 1)] = 1
|
y[torch.gt(y, 1)] = 1
|
||||||
yuv_jinc[1]:copy(y)
|
yuv_jinc[1]:copy(y)
|
||||||
|
@ -87,5 +119,72 @@ function reconstruct.scale(model, scale, x, offset, block_size)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
end
|
end
|
||||||
|
function reconstruct.image_rgb(model, x, offset, block_size)
|
||||||
|
block_size = block_size or 128
|
||||||
|
local output_size = block_size - offset * 2
|
||||||
|
local h_blocks = math.floor(x:size(2) / output_size) +
|
||||||
|
((x:size(2) % output_size == 0 and 0) or 1)
|
||||||
|
local w_blocks = math.floor(x:size(3) / output_size) +
|
||||||
|
((x:size(3) % output_size == 0 and 0) or 1)
|
||||||
|
|
||||||
|
local h = offset + h_blocks * output_size + offset
|
||||||
|
local w = offset + w_blocks * output_size + offset
|
||||||
|
local pad_h1 = offset
|
||||||
|
local pad_w1 = offset
|
||||||
|
local pad_h2 = (h - offset) - x:size(2)
|
||||||
|
local pad_w2 = (w - offset) - x:size(3)
|
||||||
|
local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
|
||||||
|
local y = reconstruct_rgb(model, input, offset, block_size)
|
||||||
|
local output = image.crop(y,
|
||||||
|
pad_w1, pad_h1,
|
||||||
|
y:size(3) - pad_w2, y:size(2) - pad_h2)
|
||||||
|
collectgarbage()
|
||||||
|
output[torch.lt(output, 0)] = 0
|
||||||
|
output[torch.gt(output, 1)] = 1
|
||||||
|
|
||||||
|
return output
|
||||||
|
end
|
||||||
|
function reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
||||||
|
block_size = block_size or 128
|
||||||
|
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
|
||||||
|
|
||||||
|
local output_size = block_size - offset * 2
|
||||||
|
local h_blocks = math.floor(x:size(2) / output_size) +
|
||||||
|
((x:size(2) % output_size == 0 and 0) or 1)
|
||||||
|
local w_blocks = math.floor(x:size(3) / output_size) +
|
||||||
|
((x:size(3) % output_size == 0 and 0) or 1)
|
||||||
|
|
||||||
|
local h = offset + h_blocks * output_size + offset
|
||||||
|
local w = offset + w_blocks * output_size + offset
|
||||||
|
local pad_h1 = offset
|
||||||
|
local pad_w1 = offset
|
||||||
|
local pad_h2 = (h - offset) - x:size(2)
|
||||||
|
local pad_w2 = (w - offset) - x:size(3)
|
||||||
|
local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
|
||||||
|
local y = reconstruct_rgb(model, input, offset, block_size)
|
||||||
|
local output = image.crop(y,
|
||||||
|
pad_w1, pad_h1,
|
||||||
|
y:size(3) - pad_w2, y:size(2) - pad_h2)
|
||||||
|
output[torch.lt(output, 0)] = 0
|
||||||
|
output[torch.gt(output, 1)] = 1
|
||||||
|
collectgarbage()
|
||||||
|
|
||||||
|
return output
|
||||||
|
end
|
||||||
|
|
||||||
|
function reconstruct.image(model, x, offset, block_size)
|
||||||
|
if model_is_rgb(model) then
|
||||||
|
return reconstruct.image_rgb(model, x, offset, block_size)
|
||||||
|
else
|
||||||
|
return reconstruct.image_y(model, x, offset, block_size)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
function reconstruct.scale(model, scale, x, offset, block_size)
|
||||||
|
if model_is_rgb(model) then
|
||||||
|
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
||||||
|
else
|
||||||
|
return reconstruct.scale_y(model, scale, x, offset, block_size)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
return reconstruct
|
return reconstruct
|
||||||
|
|
|
@ -22,6 +22,7 @@ cmd:option("-test", "images/miku_small.png", 'test image file')
|
||||||
cmd:option("-model_dir", "./models", 'model directory')
|
cmd:option("-model_dir", "./models", 'model directory')
|
||||||
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
|
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
|
||||||
cmd:option("-noise_level", 1, '(1|2)')
|
cmd:option("-noise_level", 1, '(1|2)')
|
||||||
|
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||||
cmd:option("-scale", 2.0, 'scale')
|
cmd:option("-scale", 2.0, 'scale')
|
||||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||||
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
|
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
|
||||||
|
@ -46,6 +47,9 @@ elseif settings.method == "noise_scale" then
|
||||||
else
|
else
|
||||||
error("unknown method: " .. settings.method)
|
error("unknown method: " .. settings.method)
|
||||||
end
|
end
|
||||||
|
if not (settings.color == "rgb" or settings.color == "y") then
|
||||||
|
error("color must be y or rgb")
|
||||||
|
end
|
||||||
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
|
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
|
||||||
error("scale must be mod-2")
|
error("scale must be mod-2")
|
||||||
end
|
end
|
||||||
|
|
|
@ -6,10 +6,22 @@ function nn.SpatialConvolutionMM:reset(stdv)
|
||||||
self.bias:fill(0)
|
self.bias:fill(0)
|
||||||
end
|
end
|
||||||
local srcnn = {}
|
local srcnn = {}
|
||||||
function srcnn.waifu2x()
|
function srcnn.waifu2x(color)
|
||||||
local model = nn.Sequential()
|
local model = nn.Sequential()
|
||||||
|
local ch = nil
|
||||||
|
if color == "rgb" then
|
||||||
|
ch = 3
|
||||||
|
elseif color == "y" then
|
||||||
|
ch = 1
|
||||||
|
else
|
||||||
|
if color then
|
||||||
|
error("unknown color: " .. color)
|
||||||
|
else
|
||||||
|
error("unknown color: nil")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
model:add(nn.SpatialConvolutionMM(1, 32, 3, 3, 1, 1, 0, 0))
|
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(nn.LeakyReLU(0.1))
|
model:add(nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(nn.LeakyReLU(0.1))
|
model:add(nn.LeakyReLU(0.1))
|
||||||
|
@ -21,7 +33,7 @@ function srcnn.waifu2x()
|
||||||
model:add(nn.LeakyReLU(0.1))
|
model:add(nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(nn.LeakyReLU(0.1))
|
model:add(nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(128, 1, 3, 3, 1, 1, 0, 0))
|
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(nn.View(-1):setNumInputDims(3))
|
model:add(nn.View(-1):setNumInputDims(3))
|
||||||
--model:cuda()
|
--model:cuda()
|
||||||
--print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
|
--print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
|
||||||
|
@ -30,10 +42,19 @@ function srcnn.waifu2x()
|
||||||
end
|
end
|
||||||
|
|
||||||
-- current 4x is worse then 2x * 2
|
-- current 4x is worse then 2x * 2
|
||||||
function srcnn.waifu4x()
|
function srcnn.waifu4x(color)
|
||||||
local model = nn.Sequential()
|
local model = nn.Sequential()
|
||||||
|
|
||||||
model:add(nn.SpatialConvolutionMM(1, 32, 9, 9, 1, 1, 0, 0))
|
local ch = nil
|
||||||
|
if color == "rgb" then
|
||||||
|
ch = 3
|
||||||
|
elseif color == "y" then
|
||||||
|
ch = 1
|
||||||
|
else
|
||||||
|
error("unknown color: " .. color)
|
||||||
|
end
|
||||||
|
|
||||||
|
model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
|
||||||
model:add(nn.LeakyReLU(0.1))
|
model:add(nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(nn.LeakyReLU(0.1))
|
model:add(nn.LeakyReLU(0.1))
|
||||||
|
@ -45,7 +66,7 @@ function srcnn.waifu4x()
|
||||||
model:add(nn.LeakyReLU(0.1))
|
model:add(nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
||||||
model:add(nn.LeakyReLU(0.1))
|
model:add(nn.LeakyReLU(0.1))
|
||||||
model:add(nn.SpatialConvolutionMM(128, 1, 5, 5, 1, 1, 0, 0))
|
model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
|
||||||
model:add(nn.View(-1):setNumInputDims(3))
|
model:add(nn.View(-1):setNumInputDims(3))
|
||||||
|
|
||||||
return model, 13
|
return model, 13
|
||||||
|
|
BIN
models/anime_style_art_rgb/noise1_model.json
Normal file
BIN
models/anime_style_art_rgb/noise1_model.t7
Normal file
BIN
models/anime_style_art_rgb/noise2_model.json
Normal file
BIN
models/anime_style_art_rgb/noise2_model.t7
Normal file
BIN
models/anime_style_art_rgb/scale2.0x_model.json
Normal file
BIN
models/anime_style_art_rgb/scale2.0x_model.t7
Normal file
23
train.lua
|
@ -59,7 +59,7 @@ local function validate(model, criterion, data)
|
||||||
end
|
end
|
||||||
|
|
||||||
local function train()
|
local function train()
|
||||||
local model, offset = settings.create_model()
|
local model, offset = settings.create_model(settings.color)
|
||||||
assert(offset == settings.block_offset)
|
assert(offset == settings.block_offset)
|
||||||
local criterion = nn.MSECriterion():cuda()
|
local criterion = nn.MSECriterion():cuda()
|
||||||
local x = torch.load(settings.images)
|
local x = torch.load(settings.images)
|
||||||
|
@ -72,6 +72,12 @@ local function train()
|
||||||
learningRate = settings.learning_rate,
|
learningRate = settings.learning_rate,
|
||||||
xBatchSize = settings.batch_size,
|
xBatchSize = settings.batch_size,
|
||||||
}
|
}
|
||||||
|
local ch = nil
|
||||||
|
if settings.color == "y" then
|
||||||
|
ch = 1
|
||||||
|
elseif settings.color == "rgb" then
|
||||||
|
ch = 3
|
||||||
|
end
|
||||||
local transformer = function(x, is_validation)
|
local transformer = function(x, is_validation)
|
||||||
if is_validation == nil then is_validation = false end
|
if is_validation == nil then is_validation = false end
|
||||||
if settings.method == "scale" then
|
if settings.method == "scale" then
|
||||||
|
@ -79,20 +85,25 @@ local function train()
|
||||||
settings.scale,
|
settings.scale,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
{ color_augment = not is_validation,
|
{ color_augment = not is_validation,
|
||||||
random_half = settings.random_half})
|
random_half = settings.random_half,
|
||||||
|
rgb = (settings.color == "rgb")
|
||||||
|
})
|
||||||
elseif settings.method == "noise" then
|
elseif settings.method == "noise" then
|
||||||
return pairwise_transform.jpeg(x,
|
return pairwise_transform.jpeg(x,
|
||||||
settings.noise_level,
|
settings.noise_level,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
{ color_augment = not is_validation,
|
{ color_augment = not is_validation,
|
||||||
random_half = settings.random_half})
|
random_half = settings.random_half,
|
||||||
|
rgb = (settings.color == "rgb")
|
||||||
|
})
|
||||||
elseif settings.method == "noise_scale" then
|
elseif settings.method == "noise_scale" then
|
||||||
return pairwise_transform.jpeg_scale(x,
|
return pairwise_transform.jpeg_scale(x,
|
||||||
settings.scale,
|
settings.scale,
|
||||||
settings.noise_level,
|
settings.noise_level,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
{ color_augment = not is_validation,
|
{ color_augment = not is_validation,
|
||||||
random_half = settings.random_half
|
random_half = settings.random_half,
|
||||||
|
rgb = (settings.color == "rgb")
|
||||||
})
|
})
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -109,8 +120,8 @@ local function train()
|
||||||
print("# " .. epoch)
|
print("# " .. epoch)
|
||||||
print(minibatch_adam(model, criterion, train_x, adam_config,
|
print(minibatch_adam(model, criterion, train_x, adam_config,
|
||||||
transformer,
|
transformer,
|
||||||
{1, settings.crop_size, settings.crop_size},
|
{ch, settings.crop_size, settings.crop_size},
|
||||||
{1, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
{ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
||||||
))
|
))
|
||||||
model:evaluate()
|
model:evaluate()
|
||||||
print("# validation")
|
print("# validation")
|
||||||
|
|
12
train.sh
|
@ -1,10 +1,10 @@
|
||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
|
|
||||||
th train.lua -method noise -noise_level 1 -model_dir models/anime_style_art -test images/miku_noisy.png
|
th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
|
||||||
th cleanup_model.lua -model models/anime_style_art/noise1_model.t7 -oformat ascii
|
th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
||||||
|
|
||||||
th train.lua -method noise -noise_level 2 -model_dir models/anime_style_art -test images/miku_noisy.png
|
th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
|
||||||
th cleanup_model.lua -model models/anime_style_art/noise2_model.t7 -oformat ascii
|
th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
||||||
|
|
||||||
th train.lua -method scale -scale 2 -model_dir models/anime_style_art -test images/miku_small.png
|
th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
|
||||||
th cleanup_model.lua -model models/anime_style_art/scale2.0x_model.t7 -oformat ascii
|
th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
||||||
|
|
|
@ -105,7 +105,7 @@ local function waifu2x()
|
||||||
cmd:option("-l", "", 'path of the image-list')
|
cmd:option("-l", "", 'path of the image-list')
|
||||||
cmd:option("-scale", 2, 'scale factor')
|
cmd:option("-scale", 2, 'scale factor')
|
||||||
cmd:option("-o", "(auto)", 'path of the output file')
|
cmd:option("-o", "(auto)", 'path of the output file')
|
||||||
cmd:option("-model_dir", "./models/anime_style_art", 'model directory')
|
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
|
||||||
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
|
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
|
||||||
cmd:option("-noise_level", 1, '(1|2)')
|
cmd:option("-noise_level", 1, '(1|2)')
|
||||||
cmd:option("-crop_size", 128, 'patch size per process')
|
cmd:option("-crop_size", 128, 'patch size per process')
|
||||||
|
|
2
web.lua
|
@ -23,7 +23,7 @@ local iproc = require './lib/iproc'
|
||||||
local reconstruct = require './lib/reconstruct'
|
local reconstruct = require './lib/reconstruct'
|
||||||
local image_loader = require './lib/image_loader'
|
local image_loader = require './lib/image_loader'
|
||||||
|
|
||||||
local MODEL_DIR = "./models/anime_style_art"
|
local MODEL_DIR = "./models/anime_style_art_rgb"
|
||||||
|
|
||||||
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
|
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||||
local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
|
local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||||
|
|