commit 1273b3609ed569a292dfaffd7e5ced8ad0879714 Author: nagadomi Date: Sat May 16 14:48:05 2015 +0900 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7cbde17 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*~ +cache/*.png diff --git a/assets/index.html b/assets/index.html new file mode 100644 index 0000000..bf8e893 --- /dev/null +++ b/assets/index.html @@ -0,0 +1,78 @@ + + + + + + + + +

waifu2x

+
+ + Fork me on GitHub + + ja/en +
+
+ Single-Image Super-Resolution for anime/fan-arts using Deep Convolutional Neural Networks. +
+
+
+ Image +
+ URL: or +
+
+ FILE: +
+
+ Limits: FileSize: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px +
+
+
+ Noise Reduction (expect JPEG Artifact) + + + +
+
+ Upscaling + + + +
+ +
+ + diff --git a/assets/index.ja.html b/assets/index.ja.html new file mode 100644 index 0000000..28052e2 --- /dev/null +++ b/assets/index.ja.html @@ -0,0 +1,85 @@ + + + + + + + + + +

waifu2x

+
+ + Fork me on GitHub + +
+
+ 深層畳み込みニューラルネットワークによる二次元画像のための超解像システム. +
+
+
+ Image +
+ URL: or +
+
+ FILE: +
+
+ 制限: サイズ: 2MB, ノイズ除去: 2560x2560px, 拡大: 1280x1280px +
+
+
+ ノイズ除去 (JPEGノイズを想定) + + + +
+
+ 拡大 + + + +
+ +
+
+ +
+ + diff --git a/cache/.gitkeep b/cache/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/cleanup_model.lua b/cleanup_model.lua new file mode 100644 index 0000000..abbaac2 --- /dev/null +++ b/cleanup_model.lua @@ -0,0 +1,69 @@ +require 'cunn' +require 'cudnn' +require './lib/LeakyReLU' + +torch.setdefaulttensortype("torch.FloatTensor") + +-- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049 +local function zeroDataSize(data) + if type(data) == 'table' then + for i = 1, #data do + data[i] = zeroDataSize(data[i]) + end + elseif type(data) == 'userdata' then + data = torch.Tensor():typeAs(data) + end + return data +end + +-- Resize the output, gradInput, etc temporary tensors to zero (so that the +-- on disk size is smaller) +local function cleanupModel(node) + if node.output ~= nil then + node.output = zeroDataSize(node.output) + end + if node.gradInput ~= nil then + node.gradInput = zeroDataSize(node.gradInput) + end + if node.finput ~= nil then + node.finput = zeroDataSize(node.finput) + end + if tostring(node) == "nn.LeakyReLU" then + if node.negative ~= nil then + node.negative = zeroDataSize(node.negative) + end + end + if tostring(node) == "nn.Dropout" then + if node.noise ~= nil then + node.noise = zeroDataSize(node.noise) + end + end + -- Recurse on nodes with 'modules' + if (node.modules ~= nil) then + if (type(node.modules) == 'table') then + for i = 1, #node.modules do + local child = node.modules[i] + cleanupModel(child) + end + end + end + + collectgarbage() +end + +local cmd = torch.CmdLine() +cmd:text() +cmd:text("cleanup model") +cmd:text("Options:") +cmd:option("-model", "./model.t7", 'path of model file') +cmd:option("-iformat", "binary", 'input format') +cmd:option("-oformat", "binary", 'output format') + +local opt = cmd:parse(arg) +local model = torch.load(opt.model, opt.iformat) +if model then + cleanupModel(model) + torch.save(opt.model, model, opt.oformat) +else + error("model not found") +end diff --git a/lib/LeakyReLU.lua b/lib/LeakyReLU.lua new file mode 100644 index 0000000..09b4f81 --- /dev/null +++ b/lib/LeakyReLU.lua @@ -0,0 +1,30 @@ +if nn.LeakyReLU then + return +end +local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module') + +function LeakyReLU:__init(negative_scale) + parent.__init(self) + self.negative_scale = negative_scale or 0.333 + self.negative = torch.Tensor() +end + +function LeakyReLU:updateOutput(input) + self.output:resizeAs(input):copy(input):abs():add(input):div(2) + self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5*self.negative_scale) + self.output:add(self.negative) + + return self.output +end + +function LeakyReLU:updateGradInput(input, gradOutput) + self.gradInput:resizeAs(gradOutput) + -- filter positive + self.negative:sign():add(1) + torch.cmul(self.gradInput, gradOutput, self.negative) + -- filter negative + self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput) + self.gradInput:add(self.negative) + + return self.gradInput +end diff --git a/lib/image_loader.lua b/lib/image_loader.lua new file mode 100644 index 0000000..0b1009d --- /dev/null +++ b/lib/image_loader.lua @@ -0,0 +1,73 @@ +local gm = require 'graphicsmagick' +require 'pl' + +local image_loader = {} + +function image_loader.decode_float(blob) + local im = image_loader.decode_byte(blob) + if im then + im = im:float():div(255) + end + return im +end +function image_loader.encode_png(tensor) + local im = gm.Image(tensor, "RGB", "DHW") + im:format("png") + return im:toBlob() +end +function image_loader.decode_byte(blob) + local load_image = function() + local im = gm.Image() + im:fromBlob(blob, #blob) + -- FIXME: How to detect that a image has an alpha channel? + if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then + -- merge alpha channel + im = im:toTensor('float', 'RGBA', 'DHW') + local w2 = im[4] + local w1 = im[4] * -1 + 1 + local new_im = torch.FloatTensor(3, im:size(2), im:size(3)) + -- apply the white background + new_im[1]:copy(im[1]):cmul(w2):add(w1) + new_im[2]:copy(im[2]):cmul(w2):add(w1) + new_im[3]:copy(im[3]):cmul(w2):add(w1) + im = new_im:mul(255):byte() + else + im = im:toTensor('byte', 'RGB', 'DHW') + end + return im + end + local state, ret = pcall(load_image) + if state then + return ret + else + return nil + end +end +function image_loader.load_float(file) + local fp = io.open(file, "rb") + local buff = fp:read("*a") + fp:close() + return image_loader.decode_float(buff) +end +function image_loader.load_byte(file) + local fp = io.open(file, "rb") + local buff = fp:read("*a") + fp:close() + return image_loader.decode_byte(buff) +end +local function test() + require 'image' + local img + img = image_loader.load_float("./a.jpg") + if img then + print(img:min()) + print(img:max()) + image.display(img) + end + img = image_loader.load_float("./b.png") + if img then + image.display(img) + end +end +--test() +return image_loader diff --git a/lib/iproc.lua b/lib/iproc.lua new file mode 100644 index 0000000..d93f33e --- /dev/null +++ b/lib/iproc.lua @@ -0,0 +1,35 @@ +local gm = require 'graphicsmagick' +local image = require 'image' +local iproc = {} + +function iproc.sample(src, width, height) + local t = "float" + if src:type() == "torch.ByteTensor" then + t = "byte" + end + local im = gm.Image(src, "RGB", "DHW") + im:sample(math.ceil(width), math.ceil(height)) + return im:toTensor(t, "RGB", "DHW") +end +function iproc.scale(src, width, height, filter) + local t = "float" + if src:type() == "torch.ByteTensor" then + t = "byte" + end + filter = filter or "Box" + local im = gm.Image(src, "RGB", "DHW") + im:size(math.ceil(width), math.ceil(height), filter) + return im:toTensor(t, "RGB", "DHW") +end +function iproc.padding(img, w1, w2, h1, h2) + local dst_height = img:size(2) + h1 + h2 + local dst_width = img:size(3) + w1 + w2 + local flow = torch.Tensor(2, dst_height, dst_width) + flow[1] = torch.ger(torch.linspace(0, dst_height -1, dst_height), torch.ones(dst_width)) + flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width)) + flow[1]:add(-h1) + flow[2]:add(-w1) + return image.warp(img, flow, "simple", false, "clamp") +end + +return iproc diff --git a/lib/minibatch_sgd.lua b/lib/minibatch_sgd.lua new file mode 100644 index 0000000..b5dad8f --- /dev/null +++ b/lib/minibatch_sgd.lua @@ -0,0 +1,63 @@ +require 'optim' +require 'cutorch' +require 'xlua' + +local function minibatch_sgd(model, criterion, + train_x, + config, transformer, + input_size, target_size) + local parameters, gradParameters = model:getParameters() + config = config or {} + local sum_loss = 0 + local count_loss = 0 + local batch_size = config.xBatchSize or 32 + local shuffle = torch.randperm(#train_x) + local c = 1 + local inputs = torch.Tensor(batch_size, + input_size[1], input_size[2], input_size[3]):cuda() + local targets = torch.Tensor(batch_size, + target_size[1] * target_size[2] * target_size[3]):cuda() + local inputs_tmp = torch.Tensor(batch_size, + input_size[1], input_size[2], input_size[3]) + local targets_tmp = torch.Tensor(batch_size, + target_size[1] * target_size[2] * target_size[3]) + + for t = 1, #train_x, batch_size do + if t + batch_size > #train_x then + break + end + xlua.progress(t, #train_x) + for i = 1, batch_size do + local x, y = transformer(train_x[shuffle[t + i - 1]]) + inputs_tmp[i]:copy(x) + targets_tmp[i]:copy(y) + end + inputs:copy(inputs_tmp) + targets:copy(targets_tmp) + + local feval = function(x) + if x ~= parameters then + parameters:copy(x) + end + gradParameters:zero() + local output = model:forward(inputs) + local f = criterion:forward(output, targets) + sum_loss = sum_loss + f + count_loss = count_loss + 1 + model:backward(inputs, criterion:backward(output, targets)) + return f, gradParameters + end + -- must use Adam!! + optim.adam(feval, parameters, config) + + c = c + 1 + if c % 10 == 0 then + collectgarbage() + end + end + xlua.progress(#train_x, #train_x) + + return { mse = sum_loss / count_loss} +end + +return minibatch_sgd diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua new file mode 100644 index 0000000..24de37f --- /dev/null +++ b/lib/pairwise_transform.lua @@ -0,0 +1,174 @@ +require 'image' +local gm = require 'graphicsmagick' +local iproc = require './iproc' +local reconstract = require './reconstract' +local pairwise_transform = {} + +function pairwise_transform.scale(src, scale, size, offset, options) + options = options or {} + local yi = torch.radom(0, src:size(2) - size - 1) + local xi = torch.random(0, src:size(3) - size - 1) + local down_scale = 1.0 / scale + local y = image.crop(src, xi, yi, xi + size, yi + size) + local flip = torch.random(1, 4) + local nega = torch.random(0, 1) + local filters = { + "Box", -- 0.012756949974688 + "Blackman", -- 0.013191924552285 + --"Cartom", -- 0.013753536746706 + --"Hanning", -- 0.013761314529647 + --"Hermite", -- 0.013850225205266 + --"SincFast", -- 0.014095824314306 + --"Jinc", -- 0.014244299255442 + } + local downscale_filter = filters[torch.random(1, #filters)] + + if r == 1 then + y = image.hflip(y) + elseif r == 2 then + y = image.vflip(y) + elseif r == 3 then + y = image.hflip(image.vflip(y)) + elseif r == 4 then + -- none + end + if options.color_augment then + y = y:float():div(255) + local color_scale = torch.Tensor(3):uniform(0.8, 1.2) + for i = 1, 3 do + y[i]:mul(color_scale[i]) + end + y[torch.lt(y, 0)] = 0 + y[torch.gt(y, 1.0)] = 1.0 + y = y:mul(255):byte() + end + local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter) + if options.noise and (options.noise_ratio or 0.5) > torch.uniform() then + -- add noise + local quality = {torch.random(70, 90)} + for i = 1, #quality do + x = gm.Image(x, "RGB", "DHW") + x:format("jpeg") + local blob, len = x:toBlob(quality[i]) + x:fromBlob(blob, len) + x = x:toTensor("byte", "RGB", "DHW") + end + end + if options.denoise_model and (options.denoise_ratio or 0.5) > torch.uniform() then + x = reconstract(options.denoise_model, x:float():div(255), offset):mul(255):byte() + end + x = iproc.scale(x, y:size(3), y:size(2)) + y = y:float():div(255) + x = x:float():div(255) + y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3)) + x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3)) + + return x, image.crop(y, offset, offset, size - offset, size - offset) +end +function pairwise_transform.jpeg_(src, quality, size, offset, color_augment) + if color_augment == nil then color_augment = true end + local yi = torch.random(0, src:size(2) - size - 1) + local xi = torch.random(0, src:size(3) - size - 1) + local y = src + local x + local flip = torch.random(1, 4) + + if color_augment then + local color_scale = torch.Tensor(3):uniform(0.8, 1.2) + y = y:float():div(255) + for i = 1, 3 do + y[i]:mul(color_scale[i]) + end + y[torch.lt(y, 0)] = 0 + y[torch.gt(y, 1.0)] = 1.0 + y = y:mul(255):byte() + end + x = y + for i = 1, #quality do + x = gm.Image(x, "RGB", "DHW") + x:format("jpeg") + local blob, len = x:toBlob(quality[i]) + x:fromBlob(blob, len) + x = x:toTensor("byte", "RGB", "DHW") + end + + y = image.crop(y, xi, yi, xi + size, yi + size) + x = image.crop(x, xi, yi, xi + size, yi + size) + x = x:float():div(255) + y = y:float():div(255) + + if flip == 1 then + y = image.hflip(y) + x = image.hflip(x) + elseif flip == 2 then + y = image.vflip(y) + x = image.vflip(x) + elseif flip == 3 then + y = image.hflip(image.vflip(y)) + x = image.hflip(image.vflip(x)) + elseif flip == 4 then + -- none + end + y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3)) + x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3)) + + return x, image.crop(y, offset, offset, size - offset, size - offset) +end +function pairwise_transform.jpeg(src, level, size, offset, color_augment) + if level == 1 then + return pairwise_transform.jpeg_(src, {torch.random(65, 85)}, + size, offset, + color_augment) + elseif level == 2 then + local r = torch.uniform() + if r > 0.6 then + return pairwise_transform.jpeg_(src, {torch.random(27, 80)}, + size, offset, + color_augment) + elseif r > 0.3 then + local quality1 = torch.random(32, 40) + local quality2 = quality1 - 5 + return pairwise_transform.jpeg_(src, {quality1, quality2}, + size, offset, + color_augment) + else + local quality1 = torch.random(47, 70) + return pairwise_transform.jpeg_(src, {quality1, quality1 - 10, quality1 - 20}, + size, offset, + color_augment) + end + else + error("unknown noise level: " .. level) + end +end + +local function test_jpeg() + local loader = require 'image_loader' + local src = loader.load_byte("a.jpg") + + for i = 2, 9 do + local y, x = pairwise_transform.jpeg_(src, {i * 10}, 128, 0, false) + image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0}) + image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0}) + --print(x:mean(), y:mean()) + end +end +local function test_scale() + require 'nn' + require 'cudnn' + require './LeakyReLU' + + local loader = require 'image_loader' + local src = loader.load_byte("e.jpg") + + for i = 1, 9 do + local y, x = pairwise_transform.scale(src, 2.0, "Box", 128, 7, {noise = true, denoise_model = torch.load("models/noise1_model.t7")}) + image.display({image = y, legend = "y:" .. (i * 10)}) + image.display({image = x, legend = "x:" .. (i * 10)}) + --print(x:mean(), y:mean()) + end +end +--test_jpeg() +--test_scale() + +return pairwise_transform diff --git a/lib/reconstract.lua b/lib/reconstract.lua new file mode 100644 index 0000000..7e4e4c0 --- /dev/null +++ b/lib/reconstract.lua @@ -0,0 +1,58 @@ +require 'image' +local iproc = require './iproc' + +local function reconstract_layer(model, x, block_size, offset) + if x:dim() == 2 then + x = x:reshape(1, x:size(1), x:size(2)) + end + local new_x = torch.Tensor():resizeAs(x):zero() + local output_size = block_size - offset * 2 + local input = torch.CudaTensor(1, 1, block_size, block_size) + + for i = 1, x:size(2), output_size do + for j = 1, x:size(3), output_size do + if i + block_size - 1 <= x:size(2) and j + block_size - 1 <= x:size(3) then + local index = {{}, + {i, i + block_size - 1}, + {j, j + block_size - 1}} + input:copy(x[index]) + local output = model:forward(input):float():view(1, output_size, output_size) + local output_index = {{}, + {i + offset, offset + i + output_size - 1}, + {offset + j, offset + j + output_size - 1}} + new_x[output_index]:copy(output) + end + end + end + return new_x +end +local function reconstract(model, x, offset, block_size) + block_size = block_size or 128 + local output_size = block_size - offset * 2 + local h_blocks = math.floor(x:size(2) / output_size) + + ((x:size(2) % output_size == 0 and 0) or 1) + local w_blocks = math.floor(x:size(3) / output_size) + + ((x:size(3) % output_size == 0 and 0) or 1) + + local h = offset + h_blocks * output_size + offset + local w = offset + w_blocks * output_size + offset + local pad_h1 = offset + local pad_w1 = offset + local pad_h2 = (h - offset) - x:size(2) + local pad_w2 = (w - offset) - x:size(3) + local yuv = image.rgb2yuv(iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)) + local y = reconstract_layer(model, yuv[1], block_size, offset) + y[torch.lt(y, 0)] = 0 + y[torch.gt(y, 1)] = 1 + yuv[1]:copy(y) + local output = image.yuv2rgb(image.crop(yuv, + pad_w1, pad_h1, + yuv:size(3) - pad_w2, yuv:size(2) - pad_h2)) + output[torch.lt(output, 0)] = 0 + output[torch.gt(output, 1)] = 1 + collectgarbage() + + return output +end + +return reconstract diff --git a/lib/settings.lua b/lib/settings.lua new file mode 100644 index 0000000..1678236 --- /dev/null +++ b/lib/settings.lua @@ -0,0 +1,58 @@ +require 'torch' +require 'cutorch' +require 'xlua' +require 'pl' + +-- global settings + +if package.preload.settings then + return package.preload.settings +end + +-- default tensor type +torch.setdefaulttensortype('torch.FloatTensor') + +local settings = {} + +local cmd = torch.CmdLine() +cmd:text() +cmd:text("waifu2x") +cmd:text("Options:") +cmd:option("-seed", 11, 'fixed input seed') +cmd:option("-data_dir", "./data", 'data directory') +cmd:option("-test", "images/miku_small.png", 'test image file') +cmd:option("-model_dir", "./models", 'model directory') +cmd:option("-method", "scale", '(noise|scale)') +cmd:option("-noise_level", 1, '(1|2)') +cmd:option("-scale", 2.0, 'scale') +cmd:option("-learning_rate", 0.00025, 'learning rate for adam') +cmd:option("-crop_size", 128, 'crop size') +cmd:option("-batch_size", 2, 'mini batch size') +cmd:option("-epoch", 200, 'epoch') +cmd:option("-core", 2, 'cpu core') + +local opt = cmd:parse(arg) +for k, v in pairs(opt) do + settings[k] = v +end +if settings.method == "noise" then + settings.model_file = string.format("%s/noise%d_model.t7", settings.model_dir, settings.noise_level) +elseif settings.method == "scale" then + settings.model_file = string.format("%s/scale%.1fx_model.t7", settings.model_dir, settings.scale) + settings.denoise_model_file = string.format("%s/noise%d_model.t7", settings.model_dir, settings.noise_level) +else + error("unknown method: " .. settings.method) +end +if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then + error("scale must be mod-2") +end +torch.setnumthreads(settings.core) + +settings.images = string.format("%s/images.t7", settings.data_dir) +settings.image_list = string.format("%s/image_list.txt", settings.data_dir) + +settings.validation_ratio = 01 +settings.validation_crops = 40 +settings.block_offset = 7 -- see srcnn.lua + +return settings diff --git a/lib/srcnn.lua b/lib/srcnn.lua new file mode 100644 index 0000000..c7d4eb3 --- /dev/null +++ b/lib/srcnn.lua @@ -0,0 +1,32 @@ +require 'cunn' +require 'cudnn' +require './LeakyReLU' + +function cudnn.SpatialConvolution:reset(stdv) + stdv = math.sqrt(2 / ( self.kW * self.kH * self.nOutputPlane)) + self.weight:normal(0, stdv) + self.bias:fill(0) +end +local function create_model() + local model = nn.Sequential() + + model:add(cudnn.SpatialConvolution(1, 32, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(128, 1, 3, 3, 1, 1, 0, 0):fastest()) + model:add(nn.View(-1):setNumInputDims(3)) +--model:cuda() +--print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size()) + + return model, 7 +end +return create_model diff --git a/models/noise1_model.t7 b/models/noise1_model.t7 new file mode 100644 index 0000000..70e6dfc Binary files /dev/null and b/models/noise1_model.t7 differ diff --git a/models/noise2_model.t7 b/models/noise2_model.t7 new file mode 100644 index 0000000..8deb555 Binary files /dev/null and b/models/noise2_model.t7 differ diff --git a/models/scale2.0x_model.t7 b/models/scale2.0x_model.t7 new file mode 100644 index 0000000..1f0f97d Binary files /dev/null and b/models/scale2.0x_model.t7 differ diff --git a/train.lua b/train.lua new file mode 100644 index 0000000..120119d --- /dev/null +++ b/train.lua @@ -0,0 +1,143 @@ +require 'cutorch' +require 'cunn' +require 'optim' +require 'xlua' +require 'pl' + +local settings = require './lib/settings' +local minibatch_sgd = require './lib/minibatch_sgd' +local iproc = require './lib/iproc' +local create_model = require './lib/srcnn' +local reconstract, reconstract_ch = require './lib/reconstract' +local pairwise_transform = require './lib/pairwise_transform' +local image_loader = require './lib/image_loader' + +local function save_test_scale(model, rgb, file) + local input = iproc.scale(rgb, + rgb:size(3) * settings.scale, + rgb:size(2) * settings.scale) + local up = reconstract(model, input, settings.block_offset) + + image.save(file, up) +end +local function save_test_jpeg(model, rgb, file) + local im, count = reconstract(model, rgb, settings.block_offset) + image.save(file, im) +end +local function split_data(x, test_size) + local index = torch.randperm(#x) + local train_size = #x - test_size + local train_x = {} + local valid_x = {} + for i = 1, train_size do + train_x[i] = x[index[i]] + end + for i = 1, test_size do + valid_x[i] = x[index[train_size + i]] + end + return train_x, valid_x +end +local function make_validation_set(x, transformer, n) + n = n or 4 + local data = {} + for i = 1, #x do + for k = 1, n do + local x, y = transformer(x[i], true) + table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)), + y = y:reshape(1, y:size(1), y:size(2), y:size(3))}) + end + xlua.progress(i, #x) + collectgarbage() + end + return data +end +local function validate(model, criterion, data) + local loss = 0 + for i = 1, #data do + local z = model:forward(data[i].x:cuda()) + loss = loss + criterion:forward(z, data[i].y:cuda()) + xlua.progress(i, #data) + if i % 10 == 0 then + collectgarbage() + end + end + return loss / #data +end + +local function train() + local model, offset = create_model() + assert(offset == settings.block_offset) + local criterion = nn.MSECriterion():cuda() + local x = torch.load(settings.images) + local train_x, valid_x = split_data(x, + math.floor(settings.validation_ratio * #x), + settings.validation_crops) + local test = image_loader.load_float(settings.test) + local adam_config = { + learningRate = settings.learning_rate, + xBatchSize = settings.batch_size, + } + local denoise_model = nil + if settings.method == "scale" and path.exists(settings.denoise_model_file) then + denoise_model = torch.load(settings.denoise_model_file) + end + local transformer = function(x, is_validation) + if is_validation == nil then is_validation = false end + if settings.method == "scale" then + return pairwise_transform.scale(x, + settings.scale, + settings.crop_size, + offset, + {color_augment = not is_validation, + noise = false, + denoise_model = nil + }) + elseif settings.method == "noise" then + return pairwise_transform.jpeg(x, settings.noise_level, + settings.crop_size, offset, + not is_validation) + end + end + local best_score = 100000.0 + print("# make validation-set") + local valid_xy = make_validation_set(valid_x, transformer, 20) + valid_x = nil + + collectgarbage() + model:cuda() + print("load .. " .. #train_x) + for epoch = 1, settings.epoch do + model:training() + print("# " .. epoch) + print(minibatch_sgd(model, criterion, train_x, adam_config, + transformer, + {1, settings.crop_size, settings.crop_size}, + {1, settings.crop_size - offset * 2, settings.crop_size - offset * 2} + )) + if epoch % 1 == 0 then + collectgarbage() + model:evaluate() + print("# validation") + local score = validate(model, criterion, valid_xy) + if score < best_score then + best_score = score + print("* update best model") + torch.save(settings.model_file, model) + if settings.method == "noise" then + local log = path.join(settings.model_dir, + ("noise%d_best.png"):format(settings.noise_level)) + save_test_jpeg(model, test, log) + elseif settings.method == "scale" then + local log = path.join(settings.model_dir, + ("scale%.1f_best.png"):format(settings.scale)) + save_test_scale(model, test, log) + end + end + print("current: " .. score .. ", best: " .. best_score) + end + end +end +torch.manualSeed(settings.seed) +cutorch.manualSeed(settings.seed) +print(settings) +train() diff --git a/waifu2x.lua b/waifu2x.lua new file mode 100644 index 0000000..dedc1f5 --- /dev/null +++ b/waifu2x.lua @@ -0,0 +1,62 @@ +require 'cudnn' +require 'sys' +require 'pl' +require './lib/LeakyReLU' + +local iproc = require './lib/iproc' +local reconstract = require './lib/reconstract' +local image_loader = require './lib/image_loader' + +local BLOCK_OFFSET = 7 + +torch.setdefaulttensortype('torch.FloatTensor') + +local function waifu2x() + local cmd = torch.CmdLine() + cmd:text() + cmd:text("waifu2x") + cmd:text("Options:") + cmd:option("-i", "images/miku_small.png", 'path of input image') + cmd:option("-o", "(auto)", 'path of output') + cmd:option("-model_dir", "./models", 'model directory') + cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)') + cmd:option("-noise_level", 1, '(1|2)') + cmd:option("-crop_size", 128, 'crop size') + local opt = cmd:parse(arg) + if opt.o == "(auto)" then + local name = path.basename(opt.i) + local e = path.extension(name) + local base = name:sub(0, name:len() - e:len()) + opt.o = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m)) + end + + local x = image_loader.load_float(opt.i) + local new_x = nil + local t = sys.clock() + if opt.m == "noise" then + local model = torch.load(path.join(opt.model_dir, + ("noise%d_model.t7"):format(opt.noise_level)), "ascii") + model:evaluate() + new_x = reconstract(model, x, BLOCK_OFFSET) + elseif opt.m == "scale" then + local model = torch.load(path.join(opt.model_dir, "scale2.0x_model.t7"), "ascii") + model:evaluate() + x = iproc.scale(x, x:size(3) * 2.0, x:size(2) * 2.0) + new_x = reconstract(model, x, BLOCK_OFFSET) + elseif opt.m == "noise_scale" then + local noise_model = torch.load(path.join(opt.model_dir, + ("noise%d_model.t7"):format(opt.noise_level)), "ascii") + local scale_model = torch.load(path.join(opt.model_dir, "scale2.0x_model.t7"), "ascii") + + noise_model:evaluate() + scale_model:evaluate() + x = reconstract(noise_model, x, BLOCK_OFFSET) + x = iproc.scale(x, x:size(3) * 2.0, x:size(2) * 2.0) + new_x = reconstract(scale_model, x, BLOCK_OFFSET) + else + error("undefined method:" .. opt.method) + end + image.save(opt.o, new_x) + print(opt.o .. ": " .. (sys.clock() - t) .. " sec") +end +waifu2x() diff --git a/web.lua b/web.lua new file mode 100644 index 0000000..074e0af --- /dev/null +++ b/web.lua @@ -0,0 +1,201 @@ +local ROOT = '/home/nagadomi/dev/waifu2x' + +_G.TURBO_SSL = true -- Enable SSL +local turbo = require 'turbo' +local uuid = require 'uuid' +local ffi = require 'ffi' +local md5 = require 'md5' +require 'torch' +require 'cudnn' +require 'pl' + +torch.setdefaulttensortype('torch.FloatTensor') +torch.setnumthreads(4) + +package.path = package.path .. ";" .. path.join(ROOT, 'lib', '?.lua') + +require 'LeakyReLU' +local iproc = require 'iproc' +local reconstract = require 'reconstract' +local image_loader = require 'image_loader' + +local noise1_model = torch.load(path.join(ROOT, "models", "noise1_model.t7"), "ascii") +local noise2_model = torch.load(path.join(ROOT, "models", "noise2_model.t7"), "ascii") +local scale20_model = torch.load(path.join(ROOT, "models", "scale2.0x_model.t7"), "ascii") + +local USE_CACHE = true +local CACHE_DIR = path.join(ROOT, "cache") +local MAX_NOISE_IMAGE = 2560 * 2560 +local MAX_SCALE_IMAGE = 1280 * 1280 +local CURL_OPTIONS = { + request_timeout = 10, + connect_timeout = 5, + allow_redirects = true, + max_redirects = 1 +} +local CURL_MAX_SIZE = 2 * 1024 * 1024 +local BLOCK_OFFSET = 7 -- see srcnn.lua + +local function valid_size(x, scale) + if scale == 0 then + return x:size(2) * x:size(3) <= MAX_NOISE_IMAGE + else + return x:size(2) * x:size(3) <= MAX_SCALE_IMAGE + end +end + +local function get_image(req) + local file = req:get_argument("file", "") + local url = req:get_argument("url", "") + local blob = nil + local img = nil + + if file and file:len() > 0 then + blob = file + img = image_loader.decode_float(blob) + elseif url and url:len() > 0 then + local res = coroutine.yield( + turbo.async.HTTPClient({verify_ca=false}, + nil, + CURL_MAX_SIZE):fetch(url, CURL_OPTIONS) + ) + if res.code == 200 then + local content_type = res.headers:get("Content-Type", true) + if type(content_type) == "table" then + content_type = content_type[1] + end + if content_type and content_type:find("image") then + blob = res.body + img = image_loader.decode_float(blob) + end + end + end + return img, blob +end + +local function apply_denoise1(x) + return reconstract(noise1_model, x, BLOCK_OFFSET) +end +local function apply_denoise2(x) + return reconstract(noise2_model, x, BLOCK_OFFSET) +end +local function apply_scale2x(x) + return reconstract(scale20_model, + iproc.scale(x, x:size(3) * 2.0, x:size(2) * 2.0), + BLOCK_OFFSET) +end +local function cache_do(cache, x, func) + if path.exists(cache) then + return image.load(cache) + else + x = func(x) + image.save(cache, x) + return x + end +end + +local function client_disconnected(handler) + return not(handler.request and + handler.request.connection and + handler.request.connection.stream and + (not handler.request.connection.stream:closed())) +end + +local APIHandler = class("APIHandler", turbo.web.RequestHandler) +function APIHandler:post() + if client_disconnected(self) then + self:set_status(400) + self:write("client disconnected") + return + end + local x, src = get_image(self) + local scale = tonumber(self:get_argument("scale", "0")) + local noise = tonumber(self:get_argument("noise", "0")) + if x and valid_size(x, scale) then + if USE_CACHE and (noise ~= 0 or scale ~= 0) then + local hash = md5.sumhexa(src) + local cache_noise1 = path.join(CACHE_DIR, hash .. "_noise1.png") + local cache_noise2 = path.join(CACHE_DIR, hash .. "_noise2.png") + local cache_scale = path.join(CACHE_DIR, hash .. "_scale.png") + local cache_noise1_scale = path.join(CACHE_DIR, hash .. "_noise1_scale.png") + local cache_noise2_scale = path.join(CACHE_DIR, hash .. "_noise2_scale.png") + + if noise == 1 then + x = cache_do(cache_noise1, x, apply_denoise1) + elseif noise == 2 then + x = cache_do(cache_noise2, x, apply_denoise2) + end + if scale == 1 or scale == 2 then + if noise == 1 then + x = cache_do(cache_noise1_scale, x, apply_scale2x) + elseif noise == 2 then + x = cache_do(cache_noise2_scale, x, apply_scale2x) + else + x = cache_do(cache_scale, x, apply_scale2x) + end + if scale == 1 then + x = iproc.scale(x, + math.floor(x:size(3) * (1.6 / 2.0) + 0.5), + math.floor(x:size(2) * (1.6 / 2.0) + 0.5), + "Jinc") + end + end + elseif noise ~= 0 or scale ~= 0 then + if noise == 1 then + x = apply_denose1(x) + elseif noise == 2 then + x = apply_denose2(x) + end + if scale == 1 then + local x16 = {math.floor(x:size(3) * 1.6 + 0.5), math.floor(x:size(2) * 1.6 + 0.5)} + x = apply_scale2x(x) + x = iproc.scale(x, x16[1], x16[2], "Jinc") + elseif scale == 2 then + x = apply_scale2x(x) + end + end + local name = uuid() .. ".png" + local blob, len = image_loader.encode_png(x) + + self:set_header("Content-Disposition", string.format('filename="%s"', name)) + self:set_header("Content-Type", "image/png") + self:set_header("Content-Length", string.format("%d", len)) + self:write(ffi.string(blob, len)) + else + if not x then + self:set_status(400) + self:write("ERROR: unsupported image format.") + else + self:set_status(400) + self:write("ERROR: max image size exceeded.") + end + end +end +local FormHandler = class("FormHandler", turbo.web.RequestHandler) +function FormHandler:get() + local lang = self.request.headers:get("Accept-Language") + if lang then + local langs = utils.split(lang, ",") + for i = 1, #langs do + langs[i] = utils.split(langs[i], ";")[1] + end + if langs[1] == "ja" then + self:write(file.read(path.join(ROOT, "assets/index.ja.html"))) + else + self:write(file.read(path.join(ROOT, "assets/index.html"))) + end + else + self:write(file.read(path.join(ROOT, "assets/index.html"))) + end +end + +local app = turbo.web.Application:new( + { + {"^/$", FormHandler}, + {"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")}, + {"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")}, + {"^/api$", APIHandler}, + } +) +app:listen(8888, "0.0.0.0", {max_body_size = CURL_MAX_SIZE}) +turbo.ioloop.instance():start()