diff --git a/README.md b/README.md index b26ca41..f26b99f 100644 --- a/README.md +++ b/README.md @@ -23,24 +23,20 @@ waifu2x is inspired by SRCNN [1]. 2D character picture (HatsuneMiku) is licensed ## Dependencies ### Hardware -- NVIDIA GPU (Compute Capability 3.0 or later) +- NVIDIA GPU ### Platform - [Torch7](http://torch.ch/) - [NVIDIA CUDA](https://developer.nvidia.com/cuda-toolkit) -- [NVIDIA cuDNN](https://developer.nvidia.com/cuDNN) ### Packages (luarocks) - cutorch - cunn -- [cudnn](https://github.com/soumith/cudnn.torch) - [graphicsmagick](https://github.com/clementfarabet/graphicsmagick) - [turbo](https://github.com/kernelsauce/turbo) - md5 - uuid -NOTE: Turbo 1.1.3 has bug in file uploading. Please install from the master branch on github. - ## Installation ### Setting Up the Command Line Tool Environment @@ -54,16 +50,15 @@ curl -s https://raw.githubusercontent.com/torch/ezinstall/master/install-all | s ``` see [Torch (easy) install](https://github.com/torch/ezinstall) -#### Install CUDA and cuDNN. +#### Install CUDA -Google! Search keyword is "install cuda ubuntu" and "install cudnn ubuntu" +Google! Search keyword: "install cuda ubuntu" #### Install packages ``` sudo luarocks install cutorch sudo luarocks install cunn -sudo luarocks install cudnn sudo apt-get install graphicsmagick libgraphicsmagick-dev sudo luarocks install graphicsmagick ``` @@ -91,21 +86,10 @@ Install luarocks packages. ``` sudo luarocks install md5 sudo luarocks install uuid -``` - -Install turbo. -``` -git clone https://github.com/kernelsauce/turbo.git -cd turbo -sudo luarocks make rockspecs/turbo-dev-1.rockspec +sudo luarocks install turbo ``` ## Web Application - -Please edit the first line in `web.lua`. -``` -local ROOT = '/path/to/waifu2x/dir' -``` Run. ``` th web.lua @@ -173,7 +157,7 @@ Genrating a file list. ``` find /path/to/image/dir -name "*.png" > data/image_list.txt ``` -(You should use PNG! In my case, waifu2x is trained with 3000 high-resolution-beautiful-PNG images.) 
+(You should use PNG! In my case, waifu2x is trained with 3000 high-resolution-noise-free-PNG images.) Converting training data. ``` @@ -183,23 +167,30 @@ th convert_data.lua ### Training a Noise Reduction(level1) model ``` -th train.lua -method noise -noise_level 1 -test images/miku_noisy.png -th cleanup_model.lua -model models/noise1_model.t7 -oformat ascii +mkdir models/my_model +th train.lua -model_dir models/my_model -method noise -noise_level 1 -test images/miku_noisy.png +th cleanup_model.lua -model models/my_model/noise1_model.t7 -oformat ascii +# usage +th waifu2x.lua -model_dir models/my_model -m noise -noise_level 1 -i images/miku_noisy.png -o output.png ``` -You can check the performance of model with `models/noise1_best.png`. +You can check the performance of model with `models/my_model/noise1_best.png`. ### Training a Noise Reduction(level2) model ``` -th train.lua -method noise -noise_level 2 -test images/miku_noisy.png -th cleanup_model.lua -model models/noise2_model.t7 -oformat ascii +th train.lua -model_dir models/my_model -method noise -noise_level 2 -test images/miku_noisy.png +th cleanup_model.lua -model models/my_model/noise2_model.t7 -oformat ascii +# usage +th waifu2x.lua -model_dir models/my_model -m noise -noise_level 2 -i images/miku_noisy.png -o output.png ``` -You can check the performance of model with `models/noise2_best.png`. +You can check the performance of model with `models/my_model/noise2_best.png`. ### Training a 2x UpScaling model ``` -th train.lua -method scale -scale 2 -test images/miku_small.png -th cleanup_model.lua -model models/scale2.0x_model.t7 -oformat ascii +th train.lua -model_dir models/my_model -method scale -scale 2 -test images/miku_small.png +th cleanup_model.lua -model models/my_model/scale2.0x_model.t7 -oformat ascii +# usage +th waifu2x.lua -model_dir models/my_model -m scale -scale 2 -i images/miku_small.png -o output.png ``` -You can check the performance of model with `models/scale2.0x_best.png`. 
+You can check the performance of model with `models/my_model/scale2.0x_best.png`. diff --git a/cleanup_model.lua b/cleanup_model.lua index abbaac2..2f91484 100644 --- a/cleanup_model.lua +++ b/cleanup_model.lua @@ -1,5 +1,4 @@ -require 'cunn' -require 'cudnn' +require './lib/portable' require './lib/LeakyReLU' torch.setdefaulttensortype("torch.FloatTensor") diff --git a/convert_data.lua b/convert_data.lua index b81869a..0fb9833 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -1,4 +1,5 @@ -require 'torch' +require './lib/portable' +require 'image' local settings = require './lib/settings' local image_loader = require './lib/image_loader' @@ -13,15 +14,21 @@ local function count_lines(file) return count end +local function crop_4x(x) + local w = x:size(3) % 4 + local h = x:size(2) % 4 + return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h) +end + local function load_images(list) local count = count_lines(list) local fp = io.open(list, "r") local x = {} local c = 0 for line in fp:lines() do - local im = image_loader.load_byte(line) + local im = crop_4x(image_loader.load_byte(line)) if im then - if im:size(2) > settings.crop_size * 2 and im:size(3) > settings.crop_size * 2 then + if im:size(2) >= settings.crop_size * 2 and im:size(3) >= settings.crop_size * 2 then table.insert(x, im) end else diff --git a/cudnn2cunn.lua b/cudnn2cunn.lua new file mode 100644 index 0000000..f35148a --- /dev/null +++ b/cudnn2cunn.lua @@ -0,0 +1,34 @@ +require 'cunn' +require 'cudnn' +require 'cutorch' +require './lib/LeakyReLU' +local srcnn = require 'lib/srcnn' + +local function cudnn2cunn(cudnn_model) + local cunn_model = srcnn.waifu2x() + local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution") + local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM") + + for i = 1, #from_seq do + local from = from_seq[i] + local to = to_seq[i] + to.weight:copy(from.weight) + to.bias:copy(from.bias) + end + cunn_model:cuda() + cunn_model:evaluate() + return cunn_model 
+end + +local cmd = torch.CmdLine() +cmd:text() +cmd:text("convert cudnn model to cunn model ") +cmd:text("Options:") +cmd:option("-model", "./model.t7", 'path of cudnn model file') +cmd:option("-iformat", "ascii", 'input format') +cmd:option("-oformat", "ascii", 'output format') + +local opt = cmd:parse(arg) +local cudnn_model = torch.load(opt.model, opt.iformat) +local cunn_model = cudnn2cunn(cudnn_model) +torch.save(opt.model, cunn_model, opt.oformat) diff --git a/export_model.lua b/export_model.lua new file mode 100644 index 0000000..bb91a0d --- /dev/null +++ b/export_model.lua @@ -0,0 +1,23 @@ +-- adapted from https://github.com/marcan/cl-waifu2x +require './lib/portable' +require './lib/LeakyReLU' +local cjson = require "cjson" + +local model = torch.load(arg[1], "ascii") + +local jmodules = {} +local modules = model:findModules("nn.SpatialConvolutionMM") +for i = 1, #modules, 1 do + local module = modules[i] + local jmod = { + kW = module.kW, + kH = module.kH, + nInputPlane = module.nInputPlane, + nOutputPlane = module.nOutputPlane, + bias = torch.totable(module.bias:float()), + weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH)) + } + table.insert(jmodules, jmod) +end + +io.write(cjson.encode(jmodules)) diff --git a/images/lena_waifu2x.png b/images/lena_waifu2x.png index 5d9ea48..036892c 100644 Binary files a/images/lena_waifu2x.png and b/images/lena_waifu2x.png differ diff --git a/images/miku_CC_BY-NC_noisy_waifu2x.png b/images/miku_CC_BY-NC_noisy_waifu2x.png index 5104e46..aa5ee50 100644 Binary files a/images/miku_CC_BY-NC_noisy_waifu2x.png and b/images/miku_CC_BY-NC_noisy_waifu2x.png differ diff --git a/images/miku_noisy_waifu2x.png b/images/miku_noisy_waifu2x.png index 43760d5..063ecfa 100644 Binary files a/images/miku_noisy_waifu2x.png and b/images/miku_noisy_waifu2x.png differ diff --git a/images/miku_small_noisy_waifu2x.png b/images/miku_small_noisy_waifu2x.png index 73a40d6..6f0f073 
100644 Binary files a/images/miku_small_noisy_waifu2x.png and b/images/miku_small_noisy_waifu2x.png differ diff --git a/images/miku_small_waifu2x.png b/images/miku_small_waifu2x.png index dd8f36e..60068f0 100644 Binary files a/images/miku_small_waifu2x.png and b/images/miku_small_waifu2x.png differ diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua index 20b1c50..017926f 100644 --- a/lib/pairwise_transform.lua +++ b/lib/pairwise_transform.lua @@ -4,84 +4,103 @@ local iproc = require './iproc' local reconstruct = require './reconstruct' local pairwise_transform = {} -function pairwise_transform.scale(src, scale, size, offset, options) - options = options or {} - local yi = torch.random(0, src:size(2) - size - 1) - local xi = torch.random(0, src:size(3) - size - 1) - local down_scale = 1.0 / scale - local y = image.crop(src, xi, yi, xi + size, yi + size) +local function random_half(src, p, min_size) + p = p or 0.5 + local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)] + if p > torch.uniform() then + return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter) + else + return src + end +end +local function color_augment(x) + local color_scale = torch.Tensor(3):uniform(0.8, 1.2) + x = x:float():div(255) + for i = 1, 3 do + x[i]:mul(color_scale[i]) + end + x[torch.lt(x, 0.0)] = 0.0 + x[torch.gt(x, 1.0)] = 1.0 + return x:mul(255):byte() +end +local function flip_augment(x, y) local flip = torch.random(1, 4) - local nega = torch.random(0, 1) + if y then + if flip == 1 then + x = image.hflip(x) + y = image.hflip(y) + elseif flip == 2 then + x = image.vflip(x) + y = image.vflip(y) + elseif flip == 3 then + x = image.hflip(image.vflip(x)) + y = image.hflip(image.vflip(y)) + elseif flip == 4 then + end + return x, y + else + if flip == 1 then + x = image.hflip(x) + elseif flip == 2 then + x = image.vflip(x) + elseif flip == 3 then + x = image.hflip(image.vflip(x)) + elseif flip == 4 then + end + return x + end +end +local 
INTERPOLATION_PADDING = 16 +function pairwise_transform.scale(src, scale, size, offset, options) + options = options or {color_augment = true, random_half = true} + if options.random_half then + src = random_half(src) + end + local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING) + local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING) + local down_scale = 1.0 / scale + local y = image.crop(src, + xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING, + xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING) local filters = { "Box", -- 0.012756949974688 "Blackman", -- 0.013191924552285 --"Cartom", -- 0.013753536746706 --"Hanning", -- 0.013761314529647 --"Hermite", -- 0.013850225205266 - --"SincFast", -- 0.014095824314306 - --"Jinc", -- 0.014244299255442 + "SincFast", -- 0.014095824314306 + "Jinc", -- 0.014244299255442 } local downscale_filter = filters[torch.random(1, #filters)] - if flip == 1 then - y = image.hflip(y) - elseif flip == 2 then - y = image.vflip(y) - elseif flip == 3 then - y = image.hflip(image.vflip(y)) - elseif flip == 4 then - -- none - end + y = flip_augment(y) if options.color_augment then - y = y:float():div(255) - local color_scale = torch.Tensor(3):uniform(0.8, 1.2) - for i = 1, 3 do - y[i]:mul(color_scale[i]) - end - y[torch.lt(y, 0)] = 0 - y[torch.gt(y, 1.0)] = 1.0 - y = y:mul(255):byte() + y = color_augment(y) end local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter) - if options.noise and (options.noise_ratio or 0.5) > torch.uniform() then - -- add noise - local quality = {torch.random(70, 90)} - for i = 1, #quality do - x = gm.Image(x, "RGB", "DHW") - x:format("jpeg") - local blob, len = x:toBlob(quality[i]) - x:fromBlob(blob, len) - x = x:toTensor("byte", "RGB", "DHW") - end - end - if options.denoise_model and (options.denoise_ratio or 0.5) > torch.uniform() then - x = reconstruct(options.denoise_model, 
x:float():div(255), offset):mul(255):byte() - end x = iproc.scale(x, y:size(3), y:size(2)) y = y:float():div(255) x = x:float():div(255) y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3)) x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3)) + + y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING) + x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING) - return x, image.crop(y, offset, offset, size - offset, size - offset) + return x, y end -function pairwise_transform.jpeg_(src, quality, size, offset, color_augment) - if color_augment == nil then color_augment = true end +function pairwise_transform.jpeg_(src, quality, size, offset, options) + options = options or {color_augment = true, random_half = true} + if options.random_half then + src = random_half(src) + end local yi = torch.random(0, src:size(2) - size - 1) local xi = torch.random(0, src:size(3) - size - 1) local y = src local x - local flip = torch.random(1, 4) - if color_augment then - local color_scale = torch.Tensor(3):uniform(0.8, 1.2) - y = y:float():div(255) - for i = 1, 3 do - y[i]:mul(color_scale[i]) - end - y[torch.lt(y, 0)] = 0 - y[torch.gt(y, 1.0)] = 1.0 - y = y:mul(255):byte() + if options.color_augment then + y = color_augment(y) end x = y for i = 1, #quality do @@ -94,48 +113,115 @@ function pairwise_transform.jpeg_(src, quality, size, offset, color_augment) y = image.crop(y, xi, yi, xi + size, yi + size) x = image.crop(x, xi, yi, xi + size, yi + size) - x = x:float():div(255) y = y:float():div(255) + x = x:float():div(255) + x, y = flip_augment(x, y) - if flip == 1 then - y = image.hflip(y) - x = image.hflip(x) - elseif flip == 2 then - y = image.vflip(y) - x = image.vflip(x) - elseif flip == 3 then - y = image.hflip(image.vflip(y)) - x = image.hflip(image.vflip(x)) - elseif flip == 4 
then - -- none - end y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3)) x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3)) return x, image.crop(y, offset, offset, size - offset, size - offset) end -function pairwise_transform.jpeg(src, level, size, offset, color_augment) +function pairwise_transform.jpeg(src, level, size, offset, options) if level == 1 then return pairwise_transform.jpeg_(src, {torch.random(65, 85)}, size, offset, - color_augment) + options) elseif level == 2 then local r = torch.uniform() if r > 0.6 then - return pairwise_transform.jpeg_(src, {torch.random(27, 80)}, + return pairwise_transform.jpeg_(src, {torch.random(27, 70)}, size, offset, - color_augment) + options) elseif r > 0.3 then - local quality1 = torch.random(32, 40) - local quality2 = quality1 - 5 + local quality1 = torch.random(37, 70) + local quality2 = quality1 - torch.random(5, 10) return pairwise_transform.jpeg_(src, {quality1, quality2}, - size, offset, - color_augment) + size, offset, + options) else - local quality1 = torch.random(47, 70) - return pairwise_transform.jpeg_(src, {quality1, quality1 - 10, quality1 - 20}, + local quality1 = torch.random(52, 70) + return pairwise_transform.jpeg_(src, + {quality1, + quality1 - torch.random(5, 15), + quality1 - torch.random(15, 25)}, size, offset, - color_augment) + options) + end + else + error("unknown noise level: " .. 
level) + end +end +function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options) + if options.random_half then + src = random_half(src) + end + local down_scale = 1.0 / scale + local filters = { + "Box", -- 0.012756949974688 + --"Blackman", -- 0.013191924552285 + --"Cartom", -- 0.013753536746706 + --"Hanning", -- 0.013761314529647 + --"Hermite", -- 0.013850225205266 + --"SincFast", -- 0.014095824314306 + --"Jinc", -- 0.014244299255442 + } + local downscale_filter = filters[torch.random(1, #filters)] + local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING) + local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING) + local y = src + local x + + if options.color_augment then + y = color_augment(y) + end + x = y + x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter) + for i = 1, #quality do + x = gm.Image(x, "RGB", "DHW") + x:format("jpeg") + local blob, len = x:toBlob(quality[i]) + x:fromBlob(blob, len) + x = x:toTensor("byte", "RGB", "DHW") + end + x = iproc.scale(x, y:size(3), y:size(2)) + y = image.crop(y, + xi, yi, + xi + size, yi + size) + x = image.crop(x, + xi, yi, + xi + size, yi + size) + x = x:float():div(255) + y = y:float():div(255) + x, y = flip_augment(x, y) + + y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3)) + x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3)) + + return x, image.crop(y, offset, offset, size - offset, size - offset) +end +function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options) + options = options or {color_augment = true, random_half = true} + if level == 1 then + return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)}, + size, offset, options) + elseif level == 2 then + local r = torch.uniform() + if r > 0.6 then + return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)}, + size, offset, options) + elseif r > 0.3 then + local quality1 = 
torch.random(37, 70) + local quality2 = quality1 - torch.random(5, 10) + return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2}, + size, offset, options) + else + local quality1 = torch.random(52, 70) + return pairwise_transform.jpeg_scale_(src, scale, + {quality1, + quality1 - torch.random(5, 15), + quality1 - torch.random(15, 25)}, + size, offset, options) end else error("unknown noise level: " .. level) @@ -143,32 +229,51 @@ function pairwise_transform.jpeg(src, level, size, offset, color_augment) end local function test_jpeg() - local loader = require 'image_loader' - local src = loader.load_byte("a.jpg") - + local loader = require './image_loader' + local src = loader.load_byte("../images/miku_CC_BY-NC.jpg") + local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, false) + image.display({image = y, legend = "y:0"}) + image.display({image = x, legend = "x:0"}) for i = 2, 9 do - local y, x = pairwise_transform.jpeg_(src, {i * 10}, 128, 0, false) + local y, x = pairwise_transform.jpeg_(pairwise_transform.random_half(src), + {i * 10}, 128, 0, {color_augment = false, random_half = true}) image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0}) image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0}) --print(x:mean(), y:mean()) end end -local function test_scale() - require 'nn' - require 'cudnn' - require './LeakyReLU' - - local loader = require 'image_loader' - local src = loader.load_byte("e.jpg") +local function test_scale() + local loader = require './image_loader' + local src = loader.load_byte("../images/miku_CC_BY-NC.jpg") for i = 1, 9 do - local y, x = pairwise_transform.scale(src, 2.0, "Box", 128, 7, {noise = true, denoise_model = torch.load("models/noise1_model.t7")}) - image.display({image = y, legend = "y:" .. (i * 10)}) - image.display({image = x, legend = "x:" .. 
(i * 10)}) + local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true}) + image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1}) + image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1}) + print(y:size(), x:size()) + --print(x:mean(), y:mean()) + end +end +local function test_jpeg_scale() + local loader = require './image_loader' + local src = loader.load_byte("../images/miku_CC_BY-NC.jpg") + for i = 1, 9 do + local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_augment = true, random_half = true}) + image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1}) + image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1}) + print(y:size(), x:size()) + --print(x:mean(), y:mean()) + end + for i = 1, 9 do + local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_augment = true, random_half = true}) + image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1}) + image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1}) + print(y:size(), x:size()) --print(x:mean(), y:mean()) end end --test_jpeg() --test_scale() +--test_jpeg_scale() return pairwise_transform diff --git a/lib/portable.lua b/lib/portable.lua new file mode 100644 index 0000000..d7cc32b --- /dev/null +++ b/lib/portable.lua @@ -0,0 +1,15 @@ +local function load_cuda() + require 'cunn' +end + +if pcall(load_cuda) then + require 'cunn' +else + --[[ TODO: fakecuda does not work. + + io.stderr:write("use FakeCUDA; if you have NVIDIA GPU, Please install cutorch and cunn. 
FakeCuda will be extremely slow.\n") + require 'torch' + require 'nn' + require('fakecuda').init(true) + --]] +end diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 1b590a1..63357fa 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -1,7 +1,7 @@ require 'image' local iproc = require './iproc' -local function reconstruct_layer(model, x, block_size, offset) +local function reconstruct_layer(model, x, offset, block_size) if x:dim() == 2 then x = x:reshape(1, x:size(1), x:size(2)) end @@ -42,7 +42,7 @@ function reconstruct.image(model, x, offset, block_size) local pad_h2 = (h - offset) - x:size(2) local pad_w2 = (w - offset) - x:size(3) local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)) - local y = reconstruct_layer(model, yuv[1], block_size, offset) + local y = reconstruct_layer(model, yuv[1], offset, block_size) y[torch.lt(y, 0)] = 0 y[torch.gt(y, 1)] = 1 yuv[1]:copy(y) @@ -74,7 +74,7 @@ function reconstruct.scale(model, scale, x, offset, block_size) local pad_w2 = (w - offset) - x:size(3) local yuv_nn = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)) local yuv_jinc = image.rgb2yuv(iproc.padding(x_jinc, pad_w1, pad_w2, pad_h1, pad_h2)) - local y = reconstruct_layer(model, yuv_nn[1], block_size, offset) + local y = reconstruct_layer(model, yuv_nn[1], offset, block_size) y[torch.lt(y, 0)] = 0 y[torch.gt(y, 1)] = 1 yuv_jinc[1]:copy(y) diff --git a/lib/settings.lua b/lib/settings.lua index 2d6522c..ba146f0 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -1,5 +1,3 @@ -require 'torch' -require 'cutorch' require 'xlua' require 'pl' @@ -22,10 +20,11 @@ cmd:option("-seed", 11, 'fixed input seed') cmd:option("-data_dir", "./data", 'data directory') cmd:option("-test", "images/miku_small.png", 'test image file') cmd:option("-model_dir", "./models", 'model directory') -cmd:option("-method", "scale", '(noise|scale)') +cmd:option("-method", "scale", '(noise|scale|noise_scale)') cmd:option("-noise_level", 1, 
'(1|2)') cmd:option("-scale", 2.0, 'scale') cmd:option("-learning_rate", 0.00025, 'learning rate for adam') +cmd:option("-random_half", 1, 'enable data augmentation using half resolution image') cmd:option("-crop_size", 128, 'crop size') cmd:option("-batch_size", 2, 'mini batch size') cmd:option("-epoch", 200, 'epoch') @@ -36,16 +35,25 @@ for k, v in pairs(opt) do settings[k] = v end if settings.method == "noise" then - settings.model_file = string.format("%s/noise%d_model.t7", settings.model_dir, settings.noise_level) + settings.model_file = string.format("%s/noise%d_model.t7", + settings.model_dir, settings.noise_level) elseif settings.method == "scale" then - settings.model_file = string.format("%s/scale%.1fx_model.t7", settings.model_dir, settings.scale) - settings.denoise_model_file = string.format("%s/noise%d_model.t7", settings.model_dir, settings.noise_level) + settings.model_file = string.format("%s/scale%.1fx_model.t7", + settings.model_dir, settings.scale) +elseif settings.method == "noise_scale" then + settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7", + settings.model_dir, settings.noise_level, settings.scale) else error("unknown method: " .. 
settings.method) end if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then error("scale must be mod-2") end +if settings.random_half == 1 then + settings.random_half = true +else + settings.random_half = false +end torch.setnumthreads(settings.core) settings.images = string.format("%s/images.t7", settings.data_dir) @@ -53,6 +61,14 @@ settings.image_list = string.format("%s/image_list.txt", settings.data_dir) settings.validation_ratio = 0.1 settings.validation_crops = 40 -settings.block_offset = 7 -- see srcnn.lua + +local srcnn = require './srcnn' +if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then + settings.create_model = srcnn.waifu4x + settings.block_offset = 13 +else + settings.create_model = srcnn.waifu2x + settings.block_offset = 7 +end return settings diff --git a/lib/srcnn.lua b/lib/srcnn.lua index c7d4eb3..f4a5dd8 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -1,32 +1,53 @@ -require 'cunn' -require 'cudnn' require './LeakyReLU' -function cudnn.SpatialConvolution:reset(stdv) +function nn.SpatialConvolutionMM:reset(stdv) stdv = math.sqrt(2 / ( self.kW * self.kH * self.nOutputPlane)) self.weight:normal(0, stdv) self.bias:fill(0) end -local function create_model() - local model = nn.Sequential() +local srcnn = {} +function srcnn.waifu2x() + local model = nn.Sequential() - model:add(cudnn.SpatialConvolution(1, 32, 3, 3, 1, 1, 0, 0):fastest()) - model:add(nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.SpatialConvolutionMM(1, 32, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1)) - 
model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1)) - model:add(cudnn.SpatialConvolution(128, 1, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1)) + model:add(nn.SpatialConvolutionMM(128, 1, 3, 3, 1, 1, 0, 0)) model:add(nn.View(-1):setNumInputDims(3)) --model:cuda() --print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size()) return model, 7 end -return create_model + +-- current 4x is worse than 2x * 2 +function srcnn.waifu4x() + local model = nn.Sequential() + + model:add(nn.SpatialConvolutionMM(1, 32, 9, 9, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1)) + model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1)) + model:add(nn.SpatialConvolutionMM(32, 64, 5, 5, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1)) + model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1)) + model:add(nn.SpatialConvolutionMM(64, 128, 5, 5, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1)) + model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1)) + model:add(nn.SpatialConvolutionMM(128, 1, 5, 5, 1, 1, 0, 0)) + model:add(nn.View(-1):setNumInputDims(3)) + + return model, 13 +end +return srcnn diff --git a/models/anime_style_art/noise1_model.json b/models/anime_style_art/noise1_model.json new file mode 100644 index 0000000..7faf17a Binary files /dev/null and b/models/anime_style_art/noise1_model.json differ diff --git a/models/noise1_model.t7 b/models/anime_style_art/noise1_model.t7 similarity index 51% rename from models/noise1_model.t7 rename to models/anime_style_art/noise1_model.t7 index 63731ad..a3e6597 100644 Binary files
a/models/noise1_model.t7 and b/models/anime_style_art/noise1_model.t7 differ diff --git a/models/anime_style_art/noise2_model.json b/models/anime_style_art/noise2_model.json new file mode 100644 index 0000000..0d9a29b Binary files /dev/null and b/models/anime_style_art/noise2_model.json differ diff --git a/models/scale2.0x_model.t7 b/models/anime_style_art/noise2_model.t7 similarity index 50% rename from models/scale2.0x_model.t7 rename to models/anime_style_art/noise2_model.t7 index 1f0f97d..51ff16c 100644 Binary files a/models/scale2.0x_model.t7 and b/models/anime_style_art/noise2_model.t7 differ diff --git a/models/anime_style_art/scale2.0x_model.json b/models/anime_style_art/scale2.0x_model.json new file mode 100644 index 0000000..ea13ac4 Binary files /dev/null and b/models/anime_style_art/scale2.0x_model.json differ diff --git a/models/noise2_model.t7 b/models/anime_style_art/scale2.0x_model.t7 similarity index 51% rename from models/noise2_model.t7 rename to models/anime_style_art/scale2.0x_model.t7 index 8deb555..c7e29ed 100644 Binary files a/models/noise2_model.t7 and b/models/anime_style_art/scale2.0x_model.t7 differ diff --git a/train.lua b/train.lua index 4560418..59b0c97 100644 --- a/train.lua +++ b/train.lua @@ -1,5 +1,4 @@ -require 'cutorch' -require 'cunn' +require './lib/portable' require 'optim' require 'xlua' require 'pl' @@ -7,7 +6,6 @@ require 'pl' local settings = require './lib/settings' local minibatch_adam = require './lib/minibatch_adam' local iproc = require './lib/iproc' -local create_model = require './lib/srcnn' local reconstruct = require './lib/reconstruct' local pairwise_transform = require './lib/pairwise_transform' local image_loader = require './lib/image_loader' @@ -61,10 +59,11 @@ local function validate(model, criterion, data) end local function train() - local model, offset = create_model() + local model, offset = settings.create_model() assert(offset == settings.block_offset) local criterion = nn.MSECriterion():cuda() local x 
= torch.load(settings.images) + local lrd_count = 0 local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x), settings.validation_crops) @@ -78,16 +77,23 @@ local function train() if settings.method == "scale" then return pairwise_transform.scale(x, settings.scale, - settings.crop_size, - offset, - {color_augment = not is_validation, - noise = false, - denoise_model = nil - }) + settings.crop_size, offset, + { color_augment = not is_validation, + random_half = settings.random_half}) elseif settings.method == "noise" then - return pairwise_transform.jpeg(x, settings.noise_level, + return pairwise_transform.jpeg(x, + settings.noise_level, settings.crop_size, offset, - not is_validation) + { color_augment = not is_validation, + random_half = settings.random_half}) + elseif settings.method == "noise_scale" then + return pairwise_transform.jpeg_scale(x, + settings.scale, + settings.noise_level, + settings.crop_size, offset, + { color_augment = not is_validation, + random_half = settings.random_half + }) end end local best_score = 100000.0 @@ -106,27 +112,38 @@ local function train() {1, settings.crop_size, settings.crop_size}, {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2} )) - if epoch % 1 == 0 then - collectgarbage() - model:evaluate() - print("# validation") - local score = validate(model, criterion, valid_xy) - if score < best_score then - best_score = score - print("* update best model") - torch.save(settings.model_file, model) - if settings.method == "noise" then - local log = path.join(settings.model_dir, - ("noise%d_best.png"):format(settings.noise_level)) - save_test_jpeg(model, test, log) - elseif settings.method == "scale" then - local log = path.join(settings.model_dir, - ("scale%.1f_best.png"):format(settings.scale)) - save_test_scale(model, test, log) - end + model:evaluate() + print("# validation") + local score = validate(model, criterion, valid_xy) + if score < best_score then + lrd_count = 0 + best_score 
= score + print("* update best model") + torch.save(settings.model_file, model) + if settings.method == "noise" then + local log = path.join(settings.model_dir, + ("noise%d_best.png"):format(settings.noise_level)) + save_test_jpeg(model, test, log) + elseif settings.method == "scale" then + local log = path.join(settings.model_dir, + ("scale%.1f_best.png"):format(settings.scale)) + save_test_scale(model, test, log) + elseif settings.method == "noise_scale" then + local log = path.join(settings.model_dir, + ("noise%d_scale%.1f_best.png"):format(settings.noise_level, + settings.scale)) + save_test_scale(model, test, log) + end + else + lrd_count = lrd_count + 1 + if lrd_count > 5 then + lrd_count = 0 + adam_config.learningRate = adam_config.learningRate * 0.8 + print("* learning rate decay: " .. adam_config.learningRate) end - print("current: " .. score .. ", best: " .. best_score) end + print("current: " .. score .. ", best: " .. best_score) + collectgarbage() end end torch.manualSeed(settings.seed) diff --git a/train.sh b/train.sh index 49cb2f1..34a3269 100755 --- a/train.sh +++ b/train.sh @@ -1,10 +1,10 @@ #!/bin/sh -th train.lua -method noise -noise_level 1 -test images/miku_noisy.png -th cleanup_model.lua -model models/noise1_model.t7 -oformat ascii +th train.lua -method noise -noise_level 1 -model_dir models/anime_style_art -test images/miku_noisy.png +th cleanup_model.lua -model models/anime_style_art/noise1_model.t7 -oformat ascii -th train.lua -method noise -noise_level 2 -test images/miku_noisy.png -th cleanup_model.lua -model models/noise2_model.t7 -oformat ascii +th train.lua -method noise -noise_level 2 -model_dir models/anime_style_art -test images/miku_noisy.png +th cleanup_model.lua -model models/anime_style_art/noise2_model.t7 -oformat ascii -th train.lua -method scale -scale 2 -test images/miku_small.png -th cleanup_model.lua -model models/scale2.0x_model.t7 -oformat ascii +th train.lua -method scale -scale 2 -model_dir models/anime_style_art -test 
images/miku_small.png +th cleanup_model.lua -model models/anime_style_art/scale2.0x_model.t7 -oformat ascii diff --git a/waifu2x.lua b/waifu2x.lua index 1764c5a..da04feb 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -1,4 +1,4 @@ -require 'cudnn' +require './lib/portable' require 'sys' require 'pl' require './lib/LeakyReLU' @@ -24,18 +24,18 @@ local function convert_image(opt) if opt.m == "noise" then local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii") model:evaluate() - new_x = reconstruct.image(model, x, BLOCK_OFFSET) + new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size) elseif opt.m == "scale" then local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii") model:evaluate() - new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET) + new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) elseif opt.m == "noise_scale" then local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii") local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii") noise_model:evaluate() scale_model:evaluate() x = reconstruct.image(noise_model, x, BLOCK_OFFSET) - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET) + new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) else error("undefined method:" .. 
opt.method) end @@ -63,17 +63,17 @@ local function convert_frames(opt) local x = image_loader.load_float(lines[i]) local new_x = nil if opt.m == "noise" and opt.noise_level == 1 then - new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET) + new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size) elseif opt.m == "noise" and opt.noise_level == 2 then new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) elseif opt.m == "scale" then - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET) + new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) elseif opt.m == "noise_scale" and opt.noise_level == 1 then x = reconstruct.image(noise1_model, x, BLOCK_OFFSET) - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET) + new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) elseif opt.m == "noise_scale" and opt.noise_level == 2 then x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET) + new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) else error("undefined method:" .. 
opt.method) end @@ -106,7 +106,7 @@ local function waifu2x() cmd:option("-l", "", 'path of the image-list') cmd:option("-scale", 2, 'scale factor') cmd:option("-o", "(auto)", 'path of the output file') - cmd:option("-model_dir", "./models", 'model directory') + cmd:option("-model_dir", "./models/anime_style_art", 'model directory') cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)') cmd:option("-noise_level", 1, '(1|2)') cmd:option("-crop_size", 128, 'patch size per process') diff --git a/web.lua b/web.lua index 483e01a..8f4220b 100644 --- a/web.lua +++ b/web.lua @@ -1,37 +1,34 @@ -local ROOT = '/home/ubuntu/waifu2x' - -_G.TURBO_SSL = true -- Enable SSL local turbo = require 'turbo' local uuid = require 'uuid' local ffi = require 'ffi' local md5 = require 'md5' -require 'torch' -require 'cudnn' require 'pl' torch.setdefaulttensortype('torch.FloatTensor') torch.setnumthreads(4) -package.path = package.path .. ";" .. path.join(ROOT, 'lib', '?.lua') +require './lib/portable' +require './lib/LeakyReLU' -require 'LeakyReLU' -local iproc = require 'iproc' -local reconstruct = require 'reconstruct' -local image_loader = require 'image_loader' +local iproc = require './lib/iproc' +local reconstruct = require './lib/reconstruct' +local image_loader = require './lib/image_loader' -local noise1_model = torch.load(path.join(ROOT, "models", "noise1_model.t7"), "ascii") -local noise2_model = torch.load(path.join(ROOT, "models", "noise2_model.t7"), "ascii") -local scale20_model = torch.load(path.join(ROOT, "models", "scale2.0x_model.t7"), "ascii") +local MODEL_DIR = "./models/anime_style_art" + +local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii") +local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii") +local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii") local USE_CACHE = true -local CACHE_DIR = path.join(ROOT, "cache") +local CACHE_DIR = "./cache" local MAX_NOISE_IMAGE = 
2560 * 2560 local MAX_SCALE_IMAGE = 1280 * 1280 local CURL_OPTIONS = { - request_timeout = 10, - connect_timeout = 5, + request_timeout = 15, + connect_timeout = 10, allow_redirects = true, - max_redirects = 1 + max_redirects = 2 } local CURL_MAX_SIZE = 2 * 1024 * 1024 local BLOCK_OFFSET = 7 -- see srcnn.lua @@ -171,8 +168,8 @@ function APIHandler:post() collectgarbage() end local FormHandler = class("FormHandler", turbo.web.RequestHandler) -local index_ja = file.read(path.join(ROOT, "assets/index.ja.html")) -local index_en = file.read(path.join(ROOT, "assets/index.html")) +local index_ja = file.read("./assets/index.ja.html") +local index_en = file.read("./assets/index.html") function FormHandler:get() local lang = self.request.headers:get("Accept-Language") if lang then @@ -193,8 +190,8 @@ end local app = turbo.web.Application:new( { {"^/$", FormHandler}, - {"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")}, - {"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")}, + {"^/index.html", turbo.web.StaticFileHandler, path.join("./assets", "index.html")}, + {"^/index.ja.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ja.html")}, {"^/api$", APIHandler}, } )