diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a950567 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +models/*/*.json binary +*.t7 binary diff --git a/.gitignore b/.gitignore index 25859df..47e9a0a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,15 @@ *~ +work/ cache/*.png -models/*.png +cache/url_* +data/ +!data/.gitkeep + +models/ +!models/anime_style_art +!models/anime_style_art_rgb +!models/ukbench +models/*/*.png + waifu2x.log + diff --git a/README.md b/README.md index e2a2a30..4cc9230 100644 --- a/README.md +++ b/README.md @@ -19,16 +19,11 @@ waifu2x is inspired by SRCNN [1]. 2D character picture (HatsuneMiku) is licensed ## Public AMI ``` -AMI ID: ami-0be01e4f -AMI NAME: waifu2x-server -Instance Type: g2.2xlarge -Region: US West (N.California) -OS: Ubuntu 14.04 -User: ubuntu -Created at: 2015-08-12 +TODO ``` ## Third Party Software + [Third-Party](https://github.com/nagadomi/waifu2x/wiki/Third-Party) ## Dependencies @@ -37,10 +32,12 @@ Created at: 2015-08-12 - NVIDIA GPU ### Platform + - [Torch7](http://torch.ch/) - [NVIDIA CUDA](https://developer.nvidia.com/cuda-toolkit) ### lualocks packages (excludes torch7's default packages) +- lua-csnappy - md5 - uuid - [turbo](https://github.com/kernelsauce/turbo) @@ -57,34 +54,44 @@ See: [NVIDIA CUDA Getting Started Guide for Linux](http://docs.nvidia.com/cuda/c Download [CUDA](http://developer.nvidia.com/cuda-downloads) ``` -sudo dpkg -i cuda-repo-ubuntu1404_7.0-28_amd64.deb +sudo dpkg -i cuda-repo-ubuntu1404_7.5-18_amd64.deb sudo apt-get update sudo apt-get install cuda ``` +#### Install Package + +``` +sudo apt-get install libsnappy-dev +``` + #### Install Torch7 See: [Getting started with Torch](http://torch.ch/docs/getting-started.html) +And install luarocks packages. +``` +luarocks install graphicsmagick # upgrade +luarocks install lua-csnappy +luarocks install md5 +luarocks install uuid +PREFIX=$HOME/torch/install luarocks install turbo # if you need to use web application +``` + +#### Getting waifu2x + +``` +git clone --depth 1 https://github.com/nagadomi/waifu2x.git +``` + #### Validation -Test the waifu2x command line tool. +Testing the waifu2x command line tool. ``` th waifu2x.lua ``` -### Setting Up the Web Application Environment (if you needed) - -#### Install packages - -``` -luarocks install md5 -luarocks install uuid -PREFIX=$HOME/torch/install luarocks install turbo -``` - ## Web Application -Run. ``` th web.lua ``` @@ -114,11 +121,11 @@ th waifu2x.lua -m noise_scale -noise_level 1 -i input_image.png -o output_image. th waifu2x.lua -m noise_scale -noise_level 2 -i input_image.png -o output_image.png ``` -See also `images/gen.sh`. +See also `th waifu2x.lua -h`. ### Video Encoding -\* `avconv` is `ffmpeg` on Ubuntu 14.04. +\* `avconv` is alias of `ffmpeg` on Ubuntu 14.04. Extracting images and audio from a video. (range: 00:09:00 ~ 00:12:00) ``` @@ -144,6 +151,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 - ``` ## Training Your Own Model +Notes: If you have cuDNN library, you can use cudnn kernel with `-backend cudnn` option. And you can convert trained cudnn model to cunn model with `tools/cudnn2cunn.lua`. ### Data Preparation @@ -151,7 +159,7 @@ Genrating a file list. ``` find /path/to/image/dir -name "*.png" > data/image_list.txt ``` -(You should use PNG! In my case, waifu2x is trained with 3000 high-resolution-noise-free-PNG images.) +You should use noise free images. 
In my case, waifu2x is trained with 6000 high-resolution, noise-free PNG images. Converting training data. ``` th convert_data.lua ``` diff --git a/appendix/purge_cache.lua b/appendix/purge_cache.lua index fa81be5..1ac0761 100644 --- a/appendix/purge_cache.lua +++ b/appendix/purge_cache.lua @@ -3,7 +3,15 @@ require 'pl' CACHE_DIR="cache" TTL = 3600 * 24 -local files = dir.getfiles(CACHE_DIR, "*.png") +local files = {} +local image_cache = dir.getfiles(CACHE_DIR, "*.png") +local url_cache = dir.getfiles(CACHE_DIR, "url_*") +for i = 1, #image_cache do + table.insert(files, image_cache[i]) +end +for i = 1, #url_cache do + table.insert(files, url_cache[i]) +end local now = os.time() for i, f in pairs(files) do if now - path.getmtime(f) > TTL then diff --git a/assets/index.html b/assets/index.html index 6a18979..8cde0f5 100644 --- a/assets/index.html +++ b/assets/index.html @@ -2,51 +2,17 @@ - waifu2x - + - +

waifu2x

-
- Fork me on GitHub - +
+ Fork me on GitHub +
en/ja/ru
@@ -66,12 +32,14 @@ Limits: Size: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px
-
+
Noise Reduction (expects JPEG artifacts) -
When using 2x scaling, we never recommend to use high level of noise reduction, it almost always makes image worse, it makes sense for only some rare cases when image had really bad quality from the beginning.
+
+
When using 2x scaling, we do not recommend a high level of noise reduction; it almost always makes the image worse. It only makes sense in rare cases where the image had really bad quality from the beginning.
+
Upscaling @@ -82,7 +50,7 @@
-
    +
• If you are using Firefox, please press the CTRL+S key to save the image. The "Save Image" option doesn't work.
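The new `cache/url_*` entries in `.gitignore` and `appendix/purge_cache.lua` above imply that the web application caches images downloaded from URLs on disk, next to the PNG result cache. A minimal sketch of what such a cache key could look like, using the `md5` rock already in the dependency list; the actual scheme in `web.lua` is not shown in this diff, so `url_cache_path` is purely illustrative:

```lua
-- illustrative only: derive a cache/url_* path from a request URL
-- (web.lua's real key scheme is not part of this diff)
local md5 = require 'md5'

local function url_cache_path(url)
   return ("cache/url_%s"):format(md5.sumhexa(url))
end

-- files named like this are what purge_cache.lua sweeps up once they
-- are older than TTL (24 hours)
print(url_cache_path("https://example.com/image.png"))
```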
diff --git a/assets/index.ja.html b/assets/index.ja.html index 357e82c..5936d14 100644 --- a/assets/index.ja.html +++ b/assets/index.ja.html @@ -2,51 +2,17 @@ - + waifu2x - - +

waifu2x

-
- Fork me on GitHub - +
+ Fork me on GitHub +
en/ja/ru
@@ -66,7 +32,7 @@ 制限: サイズ: 2MB, ノイズ除去: 2560x2560px, 拡大: 1280x1280px
-
+
ノイズ除去 (JPEGノイズを想定) @@ -81,7 +47,7 @@
-
    +
    • なし/なしで入力画像を変換せずに出力する。ブラウザのタブで変換結果を比較したい人用。
    • Firefoxの方は、右クリから画像が保存できないようなので、CTRL+SキーかALTキー後 ファイル - ページを保存 で画像を保存してください。
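The limits repeated on each localized page (Size: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px) would have to be enforced server-side as well. `web.lua` is not part of this diff, so the following is only a sketch of such a check; the names `LIMITS` and `within_limits` are illustrative:

```lua
-- hypothetical server-side counterpart to the limits shown on the page;
-- names and structure are illustrative, not taken from web.lua
local LIMITS = {
   blob_bytes = 2 * 1024 * 1024, -- "Size: 2MB"
   noise_px   = 2560,            -- "Noise Reduction: 2560x2560px"
   scale_px   = 1280,            -- "Upscaling: 1280x1280px"
}

local function within_limits(blob, width, height, method)
   if #blob > LIMITS.blob_bytes then
      return false, "file too large"
   end
   local max_px = (method == "scale") and LIMITS.scale_px or LIMITS.noise_px
   if width > max_px or height > max_px then
      return false, "image too large"
   end
   return true
end
```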
    diff --git a/assets/index.ru.html b/assets/index.ru.html index b3cdcdd..7d45890 100644 --- a/assets/index.ru.html +++ b/assets/index.ru.html @@ -2,51 +2,18 @@ - + waifu2x - + - +

    waifu2x

    -
    - Fork me on GitHub - +
    + Fork me on GitHub +
    en/ja/ru
    @@ -66,11 +33,11 @@ Макс. размер файла — 2MB, устранение шума — макс. 2560x2560px, апскейл — 1280x1280px
-
- Устранение шума (артефактов JPEG) +
+ Устранение шума (артефактов JPEG) - +
Устранение шума нужно использовать, если на картинке действительно есть шум, иначе это даст противоположный эффект. Также не рекомендуется сильное устранение шума, оно даёт выгоду только в редких случаях, когда картинка изначально была сильно испорчена.
@@ -82,8 +49,9 @@
-
    +
    • Если Вы используете Firefox, для сохранения изображения Вам придётся нажать Ctrl+S (опция в меню "Сохранить изображение" работать не будет!) +
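With the three localized pages in place, the en/ja/ru links at the top of each page switch between them. The dispatch itself lives in `web.lua`, which this diff does not touch; a sketch of the mapping, with illustrative names:

```lua
-- illustrative mapping from the en/ja/ru links to the localized pages;
-- the real routing in web.lua is not shown in this diff
local PAGES = {
   en = "assets/index.html",
   ja = "assets/index.ja.html",
   ru = "assets/index.ru.html",
}

local function page_for(lang)
   return PAGES[lang] or PAGES.en -- fall back to English
end
```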
diff --git a/assets/style.css b/assets/style.css new file mode 100644 index 0000000..5968f60 --- /dev/null +++ b/assets/style.css @@ -0,0 +1,52 @@ +body { + margin: 1em 2em 1em 2em; + background: LightGray; + width: 640px; +} +fieldset { + margin-top: 1em; + margin-bottom: 1em; +} +.about { + position: relative; + display: inline-block; + font-size: 0.9em; + padding: 1em 5px 0.2em 0; +} +.help { + font-size: 0.8em; + margin: 1em 0 0 0; +} +.github-banner { + position:absolute; + display:block; + top:0; + left:540px; + max-height:140px; +} +.github-banner-image { + position: absolute; + display: block; + left: 0; + top: 0; + width: 149px; + height: 149px; + border: 0; +} +.github-banner-link { + position: absolute; + display: block; + left:0; + top:0; + width:149px; + height:130px; +} +.padding-left { + padding-left: 15px; +} +.hide { + display: none; +} +.experimental { + margin-bottom: 1em; +} diff --git a/assets/ui.js b/assets/ui.js new file mode 100644 index 0000000..0d93ff4 --- /dev/null +++ b/assets/ui.js @@ -0,0 +1,80 @@ +$(function (){ + function clear_file() { + var new_file = $("#file").clone(); + new_file.change(clear_url); + $("#file").replaceWith(new_file); + } + function clear_url() { + $("#url").val("") + } + function on_change_style(e) { + $("input[name=style]").parents("label").each( + function (i, elm) { + $(elm).css("font-weight", "normal"); + }); + var checked = $("input[name=style]:checked"); + checked.parents("label").css("font-weight", "bold"); + if (checked.val() == "art") { + $("h1").text("waifu2x"); + } else { + $("h1").html("w/a/ifu2x"); + } + } + function on_change_noise_level(e) + { + $("input[name=noise]").parents("label").each( + function (i, elm) { + $(elm).css("font-weight", "normal"); + }); + var checked = $("input[name=noise]:checked"); + if (checked.val() != 0) { + checked.parents("label").css("font-weight", "bold"); + } + } + function on_change_scale_factor(e) + { + $("input[name=scale]").parents("label").each( + function (i, elm) { + $(elm).css("font-weight", "normal"); + }); + var checked = $("input[name=scale]:checked"); + if (checked.val() != 0) { + checked.parents("label").css("font-weight", "bold"); + } + } + function on_change_white_noise(e) + { + $("input[name=white_noise]").parents("label").each( + function (i, elm) { + $(elm).css("font-weight", "normal"); + }); + var checked = $("input[name=white_noise]:checked"); + if (checked.val() != 0) { + checked.parents("label").css("font-weight", "bold"); + } + } + function on_click_experimental_button(e) + { + if ($(this).hasClass("close")) { + $(".experimental .container").show(); + $(this).removeClass("close"); + } else { + $(".experimental .container").hide(); + $(this).addClass("close"); + } + e.preventDefault(); + e.stopPropagation(); + } + + $("#url").change(clear_file); + $("#file").change(clear_url); + //$("input[name=style]").change(on_change_style); + $("input[name=noise]").change(on_change_noise_level); + $("input[name=scale]").change(on_change_scale_factor); + //$("input[name=white_noise]").change(on_change_white_noise); + //$(".experimental .button").click(on_click_experimental_button) + + //on_change_style(); + on_change_scale_factor(); + on_change_noise_level(); +}) diff --git a/convert_data.lua b/convert_data.lua index 0fb9833..3974e20 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -1,48 +1,47 @@ -require './lib/portable' +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "lib", 
"?.lua;") .. package.path + +require 'pl' require 'image' -local settings = require './lib/settings' -local image_loader = require './lib/image_loader' - -local function count_lines(file) - local fp = io.open(file, "r") - local count = 0 - for line in fp:lines() do - count = count + 1 - end - fp:close() - - return count -end - -local function crop_4x(x) - local w = x:size(3) % 4 - local h = x:size(2) % 4 - return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h) -end +local compression = require 'compression' +local settings = require 'settings' +local image_loader = require 'image_loader' +local iproc = require 'iproc' local function load_images(list) - local count = count_lines(list) - local fp = io.open(list, "r") + local MARGIN = 32 + local lines = utils.split(file.read(list), "\n") local x = {} - local c = 0 - for line in fp:lines() do - local im = crop_4x(image_loader.load_byte(line)) - if im then - if im:size(2) >= settings.crop_size * 2 and im:size(3) >= settings.crop_size * 2 then - table.insert(x, im) - end + for i = 1, #lines do + local line = lines[i] + local im, alpha = image_loader.load_byte(line) + if alpha then + io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line)) else - print("error:" .. line) + im = iproc.crop_mod4(im) + local scale = 1.0 + if settings.random_half then + scale = 2.0 + end + if im then + if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then + table.insert(x, compression.compress(im)) + else + io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN)) + end + else + io.stderr:write(string.format("\n%s: skip: load error.\n", line)) + end end - c = c + 1 - xlua.progress(c, count) - if c % 10 == 0 then + xlua.progress(i, #lines) + if i % 10 == 0 then collectgarbage() end end return x end + +torch.manualSeed(settings.seed) print(settings) local x = load_images(settings.image_list) torch.save(settings.images, x) - diff --git a/cudnn2cunn.lua b/cudnn2cunn.lua deleted file mode 100644 index fa6bfcf..0000000 --- a/cudnn2cunn.lua +++ /dev/null @@ -1,34 +0,0 @@ -require 'cunn' -require 'cudnn' -require 'cutorch' -require './lib/LeakyReLU' -local srcnn = require 'lib/srcnn' - -local function cudnn2cunn(cudnn_model) - local cunn_model = srcnn.waifu2x("y") - local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution") - local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM") - - for i = 1, #from_seq do - local from = from_seq[i] - local to = to_seq[i] - to.weight:copy(from.weight) - to.bias:copy(from.bias) - end - cunn_model:cuda() - cunn_model:evaluate() - return cunn_model -end - -local cmd = torch.CmdLine() -cmd:text() -cmd:text("convert cudnn model to cunn model ") -cmd:text("Options:") -cmd:option("-model", "./model.t7", 'path of cudnn model file') -cmd:option("-iformat", "ascii", 'input format') -cmd:option("-oformat", "ascii", 'output format') - -local opt = cmd:parse(arg) -local cudnn_model = torch.load(opt.model, opt.iformat) -local cunn_model = cudnn2cunn(cudnn_model) -torch.save(opt.model, cunn_model, opt.oformat) diff --git a/data/.gitkeep b/data/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/export_model.lua b/export_model.lua deleted file mode 100644 index bb91a0d..0000000 --- a/export_model.lua +++ /dev/null @@ -1,23 +0,0 @@ --- adapted from https://github.com/marcan/cl-waifu2x -require './lib/portable' -require './lib/LeakyReLU' -local cjson = require "cjson" - 
-local model = torch.load(arg[1], "ascii") - -local jmodules = {} -local modules = model:findModules("nn.SpatialConvolutionMM") -for i = 1, #modules, 1 do - local module = modules[i] - local jmod = { - kW = module.kW, - kH = module.kH, - nInputPlane = module.nInputPlane, - nOutputPlane = module.nOutputPlane, - bias = torch.totable(module.bias:float()), - weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH)) - } - table.insert(jmodules, jmod) -end - -io.write(cjson.encode(jmodules)) diff --git a/images/gen.sh b/images/gen.sh index df5bf97..b88570c 100755 --- a/images/gen.sh +++ b/images/gen.sh @@ -1,8 +1,7 @@ #!/bin/sh -th waifu2x.lua -noise_level 1 -m noise_scale -i images/miku_small.png -o images/miku_small_waifu2x.png +th waifu2x.lua -m scale -i images/miku_small.png -o images/miku_small_waifu2x.png th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_small_noisy.png -o images/miku_small_noisy_waifu2x.png th waifu2x.lua -noise_level 2 -m noise -i images/miku_noisy.png -o images/miku_noisy_waifu2x.png -th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_CC_BY-NC_noisy.jpg -o images/miku_CC_BY-NC_noisy_waifu2x.png th waifu2x.lua -noise_level 2 -m noise -i images/lena.png -o images/lena_waifu2x.png th waifu2x.lua -m scale -model_dir models/ukbench -i images/lena.png -o images/lena_waifu2x_ukbench.png diff --git a/images/lena_waifu2x.png b/images/lena_waifu2x.png index f1d2813..b7faaf3 100644 Binary files a/images/lena_waifu2x.png and b/images/lena_waifu2x.png differ diff --git a/images/lena_waifu2x_ukbench.png b/images/lena_waifu2x_ukbench.png index 899e89f..695dcdf 100644 Binary files a/images/lena_waifu2x_ukbench.png and b/images/lena_waifu2x_ukbench.png differ diff --git a/images/miku_CC_BY-NC_noisy_waifu2x.png b/images/miku_CC_BY-NC_noisy_waifu2x.png index 44c711c..a956238 100644 Binary files a/images/miku_CC_BY-NC_noisy_waifu2x.png and b/images/miku_CC_BY-NC_noisy_waifu2x.png differ diff --git a/images/miku_noisy_waifu2x.png b/images/miku_noisy_waifu2x.png index 51bc8bd..ce0148d 100644 Binary files a/images/miku_noisy_waifu2x.png and b/images/miku_noisy_waifu2x.png differ diff --git a/images/miku_small.png b/images/miku_small.png index 17cd2a9..ee2dd4a 100644 Binary files a/images/miku_small.png and b/images/miku_small.png differ diff --git a/images/miku_small_lanczos3.png b/images/miku_small_lanczos3.png index 3b92e90..01e9f1e 100644 Binary files a/images/miku_small_lanczos3.png and b/images/miku_small_lanczos3.png differ diff --git a/images/miku_small_noisy_waifu2x.png b/images/miku_small_noisy_waifu2x.png index 9c0db8a..dd21365 100644 Binary files a/images/miku_small_noisy_waifu2x.png and b/images/miku_small_noisy_waifu2x.png differ diff --git a/images/miku_small_waifu2x.png b/images/miku_small_waifu2x.png index fe50439..e9788fe 100644 Binary files a/images/miku_small_waifu2x.png and b/images/miku_small_waifu2x.png differ diff --git a/images/slide.odp b/images/slide.odp index 8121eaf..2ba4ecd 100644 Binary files a/images/slide.odp and b/images/slide.odp differ diff --git a/images/slide.png b/images/slide.png index 149f075..6c3af06 100644 Binary files a/images/slide.png and b/images/slide.png differ diff --git a/images/slide_noise_reduction.png b/images/slide_noise_reduction.png index 8ba7922..535cb9c 100644 Binary files a/images/slide_noise_reduction.png and b/images/slide_noise_reduction.png differ diff --git a/images/slide_result.png b/images/slide_result.png index e2b4879..28d73d1 100644 Binary files 
a/images/slide_result.png and b/images/slide_result.png differ diff --git a/images/slide_upscaling.png b/images/slide_upscaling.png index 137f9dd..2b8894b 100644 Binary files a/images/slide_upscaling.png and b/images/slide_upscaling.png differ diff --git a/lib/ClippedWeightedHuberCriterion.lua b/lib/ClippedWeightedHuberCriterion.lua new file mode 100644 index 0000000..77f83a4 --- /dev/null +++ b/lib/ClippedWeightedHuberCriterion.lua @@ -0,0 +1,39 @@ +-- ref: https://en.wikipedia.org/wiki/Huber_loss +local ClippedWeightedHuberCriterion, parent = torch.class('w2nn.ClippedWeightedHuberCriterion','nn.Criterion') + +function ClippedWeightedHuberCriterion:__init(w, gamma, clip) + parent.__init(self) + self.clip = clip + self.gamma = gamma or 1.0 + self.weight = w:clone() + self.diff = torch.Tensor() + self.diff_abs = torch.Tensor() + --self.outlier_rate = 0.0 + self.square_loss_buff = torch.Tensor() + self.linear_loss_buff = torch.Tensor() +end +function ClippedWeightedHuberCriterion:updateOutput(input, target) + self.diff:resizeAs(input):copy(input) + self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1] + self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2] + for i = 1, input:size(1) do + self.diff[i]:add(-1, target[i]):cmul(self.weight) + end + self.diff_abs:resizeAs(self.diff):copy(self.diff):abs() + + local square_targets = self.diff[torch.lt(self.diff_abs, self.gamma)] + local linear_targets = self.diff[torch.ge(self.diff_abs, self.gamma)] + local square_loss = self.square_loss_buff:resizeAs(square_targets):copy(square_targets):pow(2.0):mul(0.5):sum() + local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():add(-0.5 * self.gamma):mul(self.gamma):sum() + + --self.outlier_rate = linear_targets:nElement() / input:nElement() + self.output = (square_loss + linear_loss) / input:nElement() + return self.output +end +function ClippedWeightedHuberCriterion:updateGradInput(input, target) + local norm = 1.0 / input:nElement() + self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm) + local outlier = torch.ge(self.diff_abs, self.gamma) + self.gradInput[outlier] = torch.sign(self.diff[outlier]) * self.gamma * norm + return self.gradInput +end diff --git a/lib/DepthExpand2x.lua b/lib/DepthExpand2x.lua new file mode 100644 index 0000000..3f28dd5 --- /dev/null +++ b/lib/DepthExpand2x.lua @@ -0,0 +1,77 @@ +if w2nn.DepthExpand2x then + return w2nn.DepthExpand2x +end +local DepthExpand2x, parent = torch.class('w2nn.DepthExpand2x','nn.Module') + +function DepthExpand2x:__init() + parent:__init() +end + +function DepthExpand2x:updateOutput(input) + local x = input + -- (batch_size, depth, height, width) + self.shape = x:size() + + assert(self.shape:size() == 4, "input must be 4d tensor") + assert(self.shape[2] % 4 == 0, "depth must be depth % 4 = 0") + -- (batch_size, width, height, depth) + x = x:transpose(2, 4) + -- (batch_size, width, height * 2, depth / 2) + x = x:reshape(self.shape[1], self.shape[4], self.shape[3] * 2, self.shape[2] / 2) + -- (batch_size, height * 2, width, depth / 2) + x = x:transpose(2, 3) + -- (batch_size, height * 2, width * 2, depth / 4) + x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4] * 2, self.shape[2] / 4) + -- (batch_size, depth / 4, height * 2, width * 2) + x = x:transpose(2, 4) + x = x:transpose(3, 4) + self.output:resizeAs(x):copy(x) -- contiguous + + return self.output +end + +function DepthExpand2x:updateGradInput(input, gradOutput) + -- (batch_size, depth / 4, height * 2, width * 2) + local x = 
gradOutput + -- (batch_size, height * 2, width * 2, depth / 4) + x = x:transpose(2, 4) + x = x:transpose(2, 3) + -- (batch_size, height * 2, width, depth / 2) + x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4], self.shape[2] / 2) + -- (batch_size, width, height * 2, depth / 2) + x = x:transpose(2, 3) + -- (batch_size, width, height, depth) + x = x:reshape(self.shape[1], self.shape[4], self.shape[3], self.shape[2]) + -- (batch_size, depth, height, width) + x = x:transpose(2, 4) + + self.gradInput:resizeAs(x):copy(x) + + return self.gradInput +end + +function DepthExpand2x.test() + require 'image' + local function show(x) + local img = torch.Tensor(3, x:size(3), x:size(4)) + img[1]:copy(x[1][1]) + img[2]:copy(x[1][2]) + img[3]:copy(x[1][3]) + image.display(img) + end + local img = image.lena() + local x = torch.Tensor(1, img:size(1) * 4, img:size(2), img:size(3)) + for i = 0, img:size(1) * 4 - 1 do + src_index = ((i % 3) + 1) + x[1][i + 1]:copy(img[src_index]) + end + show(x) + + local de2x = w2nn.DepthExpand2x() + out = de2x:forward(x) + show(out) + out = de2x:updateGradInput(x, out) + show(out) +end + +return DepthExpand2x diff --git a/lib/LeakyReLU.lua b/lib/LeakyReLU.lua index 09b4f81..5b27bc9 100644 --- a/lib/LeakyReLU.lua +++ b/lib/LeakyReLU.lua @@ -1,7 +1,8 @@ -if nn.LeakyReLU then - return +if w2nn and w2nn.LeakyReLU then + return w2nn.LeakyReLU end -local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module') + +local LeakyReLU, parent = torch.class('w2nn.LeakyReLU','nn.Module') function LeakyReLU:__init(negative_scale) parent.__init(self) diff --git a/lib/LeakyReLU_deprecated.lua b/lib/LeakyReLU_deprecated.lua new file mode 100644 index 0000000..4f9fc98 --- /dev/null +++ b/lib/LeakyReLU_deprecated.lua @@ -0,0 +1,31 @@ +if nn.LeakyReLU then + return nn.LeakyReLU +end + +local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module') + +function LeakyReLU:__init(negative_scale) + parent.__init(self) + self.negative_scale = negative_scale or 0.333 + self.negative = torch.Tensor() +end + +function LeakyReLU:updateOutput(input) + self.output:resizeAs(input):copy(input):abs():add(input):div(2) + self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5*self.negative_scale) + self.output:add(self.negative) + + return self.output +end + +function LeakyReLU:updateGradInput(input, gradOutput) + self.gradInput:resizeAs(gradOutput) + -- filter positive + self.negative:sign():add(1) + torch.cmul(self.gradInput, gradOutput, self.negative) + -- filter negative + self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput) + self.gradInput:add(self.negative) + + return self.gradInput +end diff --git a/lib/WeightedMSECriterion.lua b/lib/WeightedMSECriterion.lua new file mode 100644 index 0000000..bc24b0a --- /dev/null +++ b/lib/WeightedMSECriterion.lua @@ -0,0 +1,25 @@ +local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion') + +function WeightedMSECriterion:__init(w) + parent.__init(self) + self.weight = w:clone() + self.diff = torch.Tensor() + self.loss = torch.Tensor() +end + +function WeightedMSECriterion:updateOutput(input, target) + self.diff:resizeAs(input):copy(input) + for i = 1, input:size(1) do + self.diff[i]:add(-1, target[i]):cmul(self.weight) + end + self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff) + self.output = self.loss:mean() + + return self.output +end + +function WeightedMSECriterion:updateGradInput(input, target) + local norm = 2.0 / input:nElement() + 
self.gradInput:resizeAs(input):copy(self.diff):mul(norm) + return self.gradInput +end diff --git a/cleanup_model.lua b/lib/cleanup_model.lua similarity index 67% rename from cleanup_model.lua rename to lib/cleanup_model.lua index 2f91484..1784992 100644 --- a/cleanup_model.lua +++ b/lib/cleanup_model.lua @@ -1,9 +1,5 @@ -require './lib/portable' -require './lib/LeakyReLU' - -torch.setdefaulttensortype("torch.FloatTensor") - -- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049 + local function zeroDataSize(data) if type(data) == 'table' then for i = 1, #data do @@ -14,7 +10,6 @@ local function zeroDataSize(data) end return data end - -- Resize the output, gradInput, etc temporary tensors to zero (so that the -- on disk size is smaller) local function cleanupModel(node) @@ -27,7 +22,7 @@ local function cleanupModel(node) if node.finput ~= nil then node.finput = zeroDataSize(node.finput) end - if tostring(node) == "nn.LeakyReLU" then + if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then if node.negative ~= nil then node.negative = zeroDataSize(node.negative) end @@ -46,23 +41,8 @@ local function cleanupModel(node) end end end - - collectgarbage() end - -local cmd = torch.CmdLine() -cmd:text() -cmd:text("cleanup model") -cmd:text("Options:") -cmd:option("-model", "./model.t7", 'path of model file') -cmd:option("-iformat", "binary", 'input format') -cmd:option("-oformat", "binary", 'output format') - -local opt = cmd:parse(arg) -local model = torch.load(opt.model, opt.iformat) -if model then +function w2nn.cleanup_model(model) cleanupModel(model) - torch.save(opt.model, model, opt.oformat) -else - error("model not found") + return model end diff --git a/lib/compression.lua b/lib/compression.lua new file mode 100644 index 0000000..50c235e --- /dev/null +++ b/lib/compression.lua @@ -0,0 +1,17 @@ +-- snapply compression for ByteTensor +require 'snappy' + +local compression = {} +compression.compress = function (bt) + local enc = snappy.compress(bt:storage():string()) + return {bt:size(), torch.ByteStorage():string(enc)} +end +compression.decompress = function(data) + local size = data[1] + local dec = snappy.decompress(data[2]:string()) + local bt = torch.ByteTensor(unpack(torch.totable(size))) + bt:storage():string(dec) + return bt +end + +return compression diff --git a/lib/data_augmentation.lua b/lib/data_augmentation.lua new file mode 100644 index 0000000..4d5aaf0 --- /dev/null +++ b/lib/data_augmentation.lua @@ -0,0 +1,104 @@ +require 'image' +local iproc = require 'iproc' + +local data_augmentation = {} + +local function pcacov(x) + local mean = torch.mean(x, 1) + local xm = x - torch.ger(torch.ones(x:size(1)), mean:squeeze()) + local c = torch.mm(xm:t(), xm) + c:div(x:size(1) - 1) + local ce, cv = torch.symeig(c, 'V') + return ce, cv +end +function data_augmentation.color_noise(src, p, factor) + factor = factor or 0.1 + if torch.uniform() < p then + local src, conversion = iproc.byte2float(src) + local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous() + local ce, cv = pcacov(src_t) + local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor) + + pca_space = torch.mm(src_t, cv):t():contiguous() + for i = 1, 3 do + pca_space[i]:mul(color_scale[i]) + end + local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src) + dest[torch.lt(dest, 0.0)] = 0.0 + dest[torch.gt(dest, 1.0)] = 1.0 + + if conversion then + dest = iproc.float2byte(dest) + end + return dest + else + return src + end +end 
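The new `lib/compression.lua` above stores each training image as a snappy-compressed byte tensor, which is what `convert_data.lua` now writes into `data/images.t7`. A minimal round-trip sketch, assuming the `lua-csnappy` rock is installed and `lib/` is on `package.path` (as `convert_data.lua` arranges):

```lua
-- round-trip check for lib/compression.lua: compress a ByteTensor
-- and verify that decompression restores it losslessly
require 'torch'
local compression = require 'compression' -- assumes lib/ is on package.path

local im = torch.ByteTensor(3, 32, 32):random(0, 255) -- stand-in RGB image
local packed = compression.compress(im)   -- {size, snappy-compressed storage}
local restored = compression.decompress(packed)

assert(restored:isSameSizeAs(im))
assert((restored:long() - im:long()):abs():sum() == 0)
```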
+function data_augmentation.overlay(src, p) + if torch.uniform() < p then + local r = torch.uniform() + local src, conversion = iproc.byte2float(src) + src = src:contiguous() + local flip = data_augmentation.flip(src) + flip:mul(r):add(src * (1.0 - r)) + if conversion then + flip = iproc.float2byte(flip) + end + return flip + else + return src + end +end +function data_augmentation.shift_1px(src) + -- reducing the even/odd issue in nearest neighbor scaler. + local direction = torch.random(1, 4) + local x_shift = 0 + local y_shift = 0 + if direction == 1 then + x_shift = 1 + y_shift = 0 + elseif direction == 2 then + x_shift = 0 + y_shift = 1 + elseif direction == 3 then + x_shift = 1 + y_shift = 1 + elseif flip == 4 then + x_shift = 0 + y_shift = 0 + end + local w = src:size(3) - x_shift + local h = src:size(2) - y_shift + w = w - (w % 4) + h = h - (h % 4) + local dest = iproc.crop(src, x_shift, y_shift, x_shift + w, y_shift + h) + return dest +end +function data_augmentation.flip(src) + local flip = torch.random(1, 4) + local tr = torch.random(1, 2) + local src, conversion = iproc.byte2float(src) + local dest + + src = src:contiguous() + if tr == 1 then + -- pass + elseif tr == 2 then + src = src:transpose(2, 3):contiguous() + end + if flip == 1 then + dest = image.hflip(src) + elseif flip == 2 then + dest = image.vflip(src) + elseif flip == 3 then + dest = image.hflip(image.vflip(src)) + elseif flip == 4 then + dest = src + end + if conversion then + dest = iproc.float2byte(dest) + end + return dest +end +return data_augmentation diff --git a/lib/image_loader.lua b/lib/image_loader.lua index 82e1fbd..9719975 100644 --- a/lib/image_loader.lua +++ b/lib/image_loader.lua @@ -1,74 +1,118 @@ local gm = require 'graphicsmagick' local ffi = require 'ffi' +local iproc = require 'iproc' require 'pl' local image_loader = {} -function image_loader.decode_float(blob) - local im, alpha = image_loader.decode_byte(blob) - if im then - im = im:float():div(255) - end - return im, alpha -end -function image_loader.encode_png(rgb, alpha) - if rgb:type() == "torch.ByteTensor" then - error("expect FloatTensor") - end +local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5) +local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5) +local background_color = 0.5 + +function image_loader.encode_png(rgb, alpha, depth) + depth = depth or 8 + rgb = iproc.byte2float(rgb) if alpha then if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then - alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW") + alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW") end local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3)) rgba[1]:copy(rgb[1]) rgba[2]:copy(rgb[2]) rgba[3]:copy(rgb[3]) rgba[4]:copy(alpha) + + if depth < 16 then + rgba:add(clip_eps8) + rgba[torch.lt(rgba, 0.0)] = 0.0 + rgba[torch.gt(rgba, 1.0)] = 1.0 + else + rgba:add(clip_eps16) + rgba[torch.lt(rgba, 0.0)] = 0.0 + rgba[torch.gt(rgba, 1.0)] = 1.0 + end local im = gm.Image():fromTensor(rgba, "RGBA", "DHW") - im:format("png") - return im:toBlob(9) + return im:depth(depth):format("PNG"):toString(9) else + if depth < 16 then + rgb = rgb:clone():add(clip_eps8) + rgb[torch.lt(rgb, 0.0)] = 0.0 + rgb[torch.gt(rgb, 1.0)] = 1.0 + else + rgb = rgb:clone():add(clip_eps16) + rgb[torch.lt(rgb, 0.0)] = 0.0 + rgb[torch.gt(rgb, 1.0)] = 1.0 + end local im = gm.Image(rgb, "RGB", "DHW") - im:format("png") - return im:toBlob(9) + 
return im:depth(depth):format("PNG"):toString(9) end end -function image_loader.save_png(filename, rgb, alpha) - local blob, len = image_loader.encode_png(rgb, alpha) +function image_loader.save_png(filename, rgb, alpha, depth) + depth = depth or 8 + local blob = image_loader.encode_png(rgb, alpha, depth) local fp = io.open(filename, "wb") - fp:write(ffi.string(blob, len)) + if not fp then + error("IO error: " .. filename) + end + fp:write(blob) fp:close() return true end -function image_loader.decode_byte(blob) +function image_loader.decode_float(blob) local load_image = function() local im = gm.Image() local alpha = nil + local gamma_lcd = 0.454545 im:fromBlob(blob, #blob) + + if im:colorspace() == "CMYK" then + im:colorspace("RGB") + end + local gamma = math.floor(im:gamma() * 1000000) / 1000000 + if gamma ~= 0 and gamma ~= gamma_lcd then + im:gammaCorrection(gamma / gamma_lcd) + end -- FIXME: How to detect that a image has an alpha channel? if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then -- split alpha channel im = im:toTensor('float', 'RGBA', 'DHW') - local sum_alpha = (im[4] - 1):sum() - if sum_alpha > 0 or sum_alpha < 0 then + local sum_alpha = (im[4] - 1.0):sum() + if sum_alpha < 0 then alpha = im[4]:reshape(1, im:size(2), im:size(3)) + -- drop full transparent background + local mask = torch.le(alpha, 0.0) + im[1][mask] = background_color + im[2][mask] = background_color + im[3][mask] = background_color end local new_im = torch.FloatTensor(3, im:size(2), im:size(3)) new_im[1]:copy(im[1]) new_im[2]:copy(im[2]) new_im[3]:copy(im[3]) - im = new_im:mul(255):byte() + im = new_im else - im = im:toTensor('byte', 'RGB', 'DHW') + im = im:toTensor('float', 'RGB', 'DHW') end - return {im, alpha} + return {im, alpha, blob} end local state, ret = pcall(load_image) if state then - return ret[1], ret[2] + return ret[1], ret[2], ret[3] else - return nil + return nil, nil, nil + end +end +function image_loader.decode_byte(blob) + local im, alpha + im, alpha, blob = image_loader.decode_float(blob) + + if im then + im = iproc.float2byte(im) + -- hmm, alpha does not convert here + return im, alpha, blob + else + return nil, nil, nil end end function image_loader.load_float(file) @@ -90,18 +134,16 @@ function image_loader.load_byte(file) return image_loader.decode_byte(buff) end local function test() - require 'image' - local img - img = image_loader.load_float("./a.jpg") - if img then - print(img:min()) - print(img:max()) - image.display(img) - end - img = image_loader.load_float("./b.png") - if img then - image.display(img) - end + torch.setdefaulttensortype("torch.FloatTensor") + local a = image_loader.load_float("../images/lena.png") + local blob = image_loader.encode_png(a) + local b = image_loader.decode_float(blob) + assert((b - a):abs():sum() == 0) + + a = image_loader.load_byte("../images/lena.png") + blob = image_loader.encode_png(a) + b = image_loader.decode_byte(blob) + assert((b:float() - a:float()):abs():sum() == 0) end --test() return image_loader diff --git a/lib/iproc.lua b/lib/iproc.lua index 9655771..2ea20c4 100644 --- a/lib/iproc.lua +++ b/lib/iproc.lua @@ -1,16 +1,78 @@ local gm = require 'graphicsmagick' local image = require 'image' -local iproc = {} -function iproc.scale(src, width, height, filter) - local t = "float" - if src:type() == "torch.ByteTensor" then - t = "byte" +local iproc = {} +local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5) + +function iproc.crop_mod4(src) + local w = src:size(3) % 4 + local h = src:size(2) % 4 + return 
iproc.crop(src, 0, 0, src:size(3) - w, src:size(2) - h) +end +function iproc.crop(src, w1, h1, w2, h2) + local dest + if src:dim() == 3 then + dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}]:clone() + else -- dim == 2 + dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}]:clone() end + return dest +end +function iproc.crop_nocopy(src, w1, h1, w2, h2) + local dest + if src:dim() == 3 then + dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}] + else -- dim == 2 + dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}] + end + return dest +end +function iproc.byte2float(src) + local conversion = false + local dest = src + if src:type() == "torch.ByteTensor" then + conversion = true + dest = src:float():div(255.0) + end + return dest, conversion +end +function iproc.float2byte(src) + local conversion = false + local dest = src + if src:type() == "torch.FloatTensor" then + conversion = true + dest = (src + clip_eps8):mul(255.0) + dest[torch.lt(dest, 0.0)] = 0 + dest[torch.gt(dest, 255.0)] = 255.0 + dest = dest:byte() + end + return dest, conversion +end +function iproc.scale(src, width, height, filter) + local conversion + src, conversion = iproc.byte2float(src) filter = filter or "Box" local im = gm.Image(src, "RGB", "DHW") im:size(math.ceil(width), math.ceil(height), filter) - return im:toTensor(t, "RGB", "DHW") + local dest = im:toTensor("float", "RGB", "DHW") + if conversion then + dest = iproc.float2byte(dest) + end + return dest +end +function iproc.scale_with_gamma22(src, width, height, filter) + local conversion + src, conversion = iproc.byte2float(src) + filter = filter or "Box" + local im = gm.Image(src, "RGB", "DHW") + im:gammaCorrection(1.0 / 2.2): + size(math.ceil(width), math.ceil(height), filter): + gammaCorrection(2.2) + local dest = im:toTensor("float", "RGB", "DHW") + if conversion then + dest = iproc.float2byte(dest) + end + return dest end function iproc.padding(img, w1, w2, h1, h2) local dst_height = img:size(2) + h1 + h2 @@ -22,5 +84,51 @@ function iproc.padding(img, w1, w2, h1, h2) flow[2]:add(-w1) return image.warp(img, flow, "simple", false, "clamp") end +function iproc.white_noise(src, std, rgb_weights, gamma) + gamma = gamma or 0.454545 + local conversion + src, conversion = iproc.byte2float(src) + std = std or 0.01 + + local noise = torch.Tensor():resizeAs(src):normal(0, std) + if rgb_weights then + noise[1]:mul(rgb_weights[1]) + noise[2]:mul(rgb_weights[2]) + noise[3]:mul(rgb_weights[3]) + end + + local dest + if gamma ~= 0 then + dest = src:clone():pow(gamma):add(noise) + dest[torch.lt(dest, 0.0)] = 0.0 + dest[torch.gt(dest, 1.0)] = 1.0 + dest:pow(1.0 / gamma) + else + dest = src + noise + end + if conversion then + dest = iproc.float2byte(dest) + end + return dest +end + +local function test_conversion() + local a = torch.linspace(0, 255, 256):float():div(255.0) + local b = iproc.float2byte(a) + local c = iproc.byte2float(a) + local d = torch.linspace(0, 255, 256) + assert((a - c):abs():sum() == 0) + assert((d:float() - b:float()):abs():sum() == 0) + + a = torch.FloatTensor({256.0, 255.0, 254.999}):div(255.0) + b = iproc.float2byte(a) + assert(b:float():sum() == 255.0 * 3) + + a = torch.FloatTensor({254.0, 254.499, 253.50001}):div(255.0) + b = iproc.float2byte(a) + print(b) + assert(b:float():sum() == 254.0 * 3) +end +--test_conversion() return iproc diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index 49e16eb..915f1b4 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -21,20 +21,15 @@ local function minibatch_adam(model, criterion, input_size[1], 
input_size[2], input_size[3]) local targets_tmp = torch.Tensor(batch_size, target_size[1] * target_size[2] * target_size[3]) - - for t = 1, #train_x, batch_size do - if t + batch_size > #train_x then - break - end + for t = 1, #train_x do xlua.progress(t, #train_x) - for i = 1, batch_size do - local x, y = transformer(train_x[shuffle[t + i - 1]]) - inputs_tmp[i]:copy(x) - targets_tmp[i]:copy(y) + local xy = transformer(train_x[shuffle[t]], false, batch_size) + for i = 1, #xy do + inputs_tmp[i]:copy(xy[i][1]) + targets_tmp[i]:copy(xy[i][2]) end inputs:copy(inputs_tmp) targets:copy(targets_tmp) - local feval = function(x) if x ~= parameters then parameters:copy(x) @@ -50,13 +45,13 @@ local function minibatch_adam(model, criterion, optim.adam(feval, parameters, config) c = c + 1 - if c % 10 == 0 then + if c % 20 == 0 then collectgarbage() end end xlua.progress(#train_x, #train_x) - return { mse = sum_loss / count_loss} + return { loss = sum_loss / count_loss} end return minibatch_adam diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua index eb81bdf..a14d0a2 100644 --- a/lib/pairwise_transform.lua +++ b/lib/pairwise_transform.lua @@ -1,69 +1,80 @@ require 'image' local gm = require 'graphicsmagick' -local iproc = require './iproc' -local reconstruct = require './reconstruct' +local iproc = require 'iproc' +local data_augmentation = require 'data_augmentation' + local pairwise_transform = {} -local function random_half(src, p, min_size) - p = p or 0.5 - local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)] - if p > torch.uniform() then +local function random_half(src, p) + if torch.uniform() < p then + local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)] return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter) else return src end end -local function color_augment(x) - local color_scale = torch.Tensor(3):uniform(0.8, 1.2) - x = x:float():div(255) - for i = 1, 3 do - x[i]:mul(color_scale[i]) - end - x[torch.lt(x, 0.0)] = 0.0 - x[torch.gt(x, 1.0)] = 1.0 - return x:mul(255):byte() -end -local function flip_augment(x, y) - local flip = torch.random(1, 4) - if y then - if flip == 1 then - x = image.hflip(x) - y = image.hflip(y) - elseif flip == 2 then - x = image.vflip(x) - y = image.vflip(y) - elseif flip == 3 then - x = image.hflip(image.vflip(x)) - y = image.hflip(image.vflip(y)) - elseif flip == 4 then +local function crop_if_large(src, max_size) + local tries = 4 + if src:size(2) > max_size and src:size(3) > max_size then + local rect + for i = 1, tries do + local yi = torch.random(0, src:size(2) - max_size) + local xi = torch.random(0, src:size(3) - max_size) + rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size) + -- ignore simple background + if rect:float():std() >= 0 then + break + end end - return x, y + return rect else - if flip == 1 then - x = image.hflip(x) - elseif flip == 2 then - x = image.vflip(x) - elseif flip == 3 then - x = image.hflip(image.vflip(x)) - elseif flip == 4 then - end - return x + return src end end -local INTERPOLATION_PADDING = 16 -function pairwise_transform.scale(src, scale, size, offset, options) - options = options or {color_augment = true, random_half = true, rgb = true} - if options.random_half then - src = random_half(src) +local function preprocess(src, crop_size, options) + local dest = src + dest = random_half(dest, options.random_half_rate) + dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size)) + dest = data_augmentation.flip(dest) + dest = 
data_augmentation.color_noise(dest, options.random_color_noise_rate) + dest = data_augmentation.overlay(dest, options.random_overlay_rate) + dest = data_augmentation.shift_1px(dest) + + return dest +end +local function active_cropping(x, y, size, p, tries) + assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3)) + local r = torch.uniform() + if p < r then + local xi = torch.random(0, y:size(3) - (size + 1)) + local yi = torch.random(0, y:size(2) - (size + 1)) + local xc = iproc.crop(x, xi, yi, xi + size, yi + size) + local yc = iproc.crop(y, xi, yi, xi + size, yi + size) + return xc, yc + else + local best_se = 0.0 + local best_xc, best_yc + local m = torch.FloatTensor(x:size(1), size, size) + for i = 1, tries do + local xi = torch.random(0, y:size(3) - (size + 1)) + local yi = torch.random(0, y:size(2) - (size + 1)) + local xc = iproc.crop(x, xi, yi, xi + size, yi + size) + local yc = iproc.crop(y, xi, yi, xi + size, yi + size) + local xcf = iproc.byte2float(xc) + local ycf = iproc.byte2float(yc) + local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum() + if se >= best_se then + best_xc = xcf + best_yc = ycf + best_se = se + end + end + return best_xc, best_yc end - local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING) - local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING) - local down_scale = 1.0 / scale - local y = image.crop(src, - xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING, - xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING) +end +function pairwise_transform.scale(src, scale, size, offset, n, options) local filters = { - "Box", -- 0.012756949974688 + "Box","Box", -- 0.012756949974688 "Blackman", -- 0.013191924552285 --"Cartom", -- 0.013753536746706 --"Hanning", -- 0.013761314529647 @@ -71,221 +82,173 @@ function pairwise_transform.scale(src, scale, size, offset, options) "SincFast", -- 0.014095824314306 "Jinc", -- 0.014244299255442 } + local unstable_region_offset = 8 local downscale_filter = filters[torch.random(1, #filters)] - - y = flip_augment(y) - if options.color_augment then - y = color_augment(y) - end - local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter) - x = iproc.scale(x, y:size(3), y:size(2)) - y = y:float():div(255) - x = x:float():div(255) - - if options.rgb then - else - y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3)) - x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3)) - end - - y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING) - x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING) - - return x, y -end -function pairwise_transform.jpeg_(src, quality, size, offset, options) - options = options or {color_augment = true, random_half = true, rgb = true} - if options.random_half then - src = random_half(src) - end - local yi = torch.random(0, src:size(2) - size - 1) - local xi = torch.random(0, src:size(3) - size - 1) - local y = src - local x - - if options.color_augment then - y = color_augment(y) - end - x = y - for i = 1, #quality do - x = gm.Image(x, "RGB", "DHW") - x:format("jpeg") - x:samplingFactors({1.0, 1.0, 1.0}) - local blob, len = x:toBlob(quality[i]) - x:fromBlob(blob, len) - x = x:toTensor("byte", "RGB", "DHW") - end - - y = image.crop(y, xi, yi, xi + size, yi + size) - x = 
image.crop(x, xi, yi, xi + size, yi + size) - y = y:float():div(255) - x = x:float():div(255) - x, y = flip_augment(x, y) - - if options.rgb then - else - y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3)) - x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3)) - end - - return x, image.crop(y, offset, offset, size - offset, size - offset) -end -function pairwise_transform.jpeg(src, level, size, offset, options) - if level == 1 then - return pairwise_transform.jpeg_(src, {torch.random(65, 85)}, - size, offset, - options) - elseif level == 2 then - local r = torch.uniform() - if r > 0.6 then - return pairwise_transform.jpeg_(src, {torch.random(27, 70)}, - size, offset, - options) - elseif r > 0.3 then - local quality1 = torch.random(37, 70) - local quality2 = quality1 - torch.random(5, 10) - return pairwise_transform.jpeg_(src, {quality1, quality2}, - size, offset, - options) - else - local quality1 = torch.random(52, 70) - return pairwise_transform.jpeg_(src, - {quality1, - quality1 - torch.random(5, 15), - quality1 - torch.random(15, 25)}, - size, offset, - options) - end - else - error("unknown noise level: " .. level) - end -end -function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options) - if options.random_half then - src = random_half(src) - end + local y = preprocess(src, size, options) + assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0) local down_scale = 1.0 / scale - local filters = { - "Box", -- 0.012756949974688 - "Blackman", -- 0.013191924552285 - --"Cartom", -- 0.013753536746706 - --"Hanning", -- 0.013761314529647 - --"Hermite", -- 0.013850225205266 - "SincFast", -- 0.014095824314306 - "Jinc", -- 0.014244299255442 - } - local downscale_filter = filters[torch.random(1, #filters)] - local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING) - local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING) - local y = src - local x + local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale, + y:size(2) * down_scale, downscale_filter), + y:size(3), y:size(2)) + x = iproc.crop(x, unstable_region_offset, unstable_region_offset, + x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset) + y = iproc.crop(y, unstable_region_offset, unstable_region_offset, + y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset) + assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0) + assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3)) - if options.color_augment then - y = color_augment(y) + local batch = {} + for i = 1, n do + local xc, yc = active_cropping(x, y, + size, + options.active_cropping_rate, + options.active_cropping_tries) + xc = iproc.byte2float(xc) + yc = iproc.byte2float(yc) + if options.rgb then + else + yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) + xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + end + table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) end - x = y - x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter) + return batch +end +function pairwise_transform.jpeg_(src, quality, size, offset, n, options) + local unstable_region_offset = 8 + local y = preprocess(src, size, options) + local x = y + for i = 1, #quality do x = gm.Image(x, "RGB", "DHW") - x:format("jpeg") - x:samplingFactors({1.0, 1.0, 1.0}) + x:format("jpeg"):depth(8) + if options.jpeg_sampling_factors == 444 then + x:samplingFactors({1.0, 1.0, 1.0}) + else -- 
420 + x:samplingFactors({2.0, 1.0, 1.0}) + end local blob, len = x:toBlob(quality[i]) x:fromBlob(blob, len) x = x:toTensor("byte", "RGB", "DHW") end - x = iproc.scale(x, y:size(3), y:size(2)) - y = image.crop(y, - xi, yi, - xi + size, yi + size) - x = image.crop(x, - xi, yi, - xi + size, yi + size) - x = x:float():div(255) - y = y:float():div(255) - x, y = flip_augment(x, y) - - if options.rgb then - else - y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3)) - x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3)) - end + x = iproc.crop(x, unstable_region_offset, unstable_region_offset, + x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset) + y = iproc.crop(y, unstable_region_offset, unstable_region_offset, + y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset) + assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0) + assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3)) - return x, image.crop(y, offset, offset, size - offset, size - offset) -end -function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options) - options = options or {color_augment = true, random_half = true} - if level == 1 then - return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)}, - size, offset, options) - elseif level == 2 then - local r = torch.uniform() - if r > 0.6 then - return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)}, - size, offset, options) - elseif r > 0.3 then - local quality1 = torch.random(37, 70) - local quality2 = quality1 - torch.random(5, 10) - return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2}, - size, offset, options) + local batch = {} + for i = 1, n do + local xc, yc = active_cropping(x, y, size, + options.active_cropping_rate, + options.active_cropping_tries) + xc = iproc.byte2float(xc) + yc = iproc.byte2float(yc) + if options.rgb then else - local quality1 = torch.random(52, 70) - return pairwise_transform.jpeg_scale_(src, scale, - {quality1, - quality1 - torch.random(5, 15), - quality1 - torch.random(15, 25)}, - size, offset, options) + yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3)) + xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3)) + end + if torch.uniform() < options.nr_rate then + -- reducing noise + table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) + else + -- ratain useful details + table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)}) + end + end + return batch +end +function pairwise_transform.jpeg(src, style, level, size, offset, n, options) + if style == "art" then + if level == 1 then + return pairwise_transform.jpeg_(src, {torch.random(65, 85)}, + size, offset, n, options) + elseif level == 2 then + local r = torch.uniform() + if r > 0.6 then + return pairwise_transform.jpeg_(src, {torch.random(27, 70)}, + size, offset, n, options) + elseif r > 0.3 then + local quality1 = torch.random(37, 70) + local quality2 = quality1 - torch.random(5, 10) + return pairwise_transform.jpeg_(src, {quality1, quality2}, + size, offset, n, options) + else + local quality1 = torch.random(52, 70) + local quality2 = quality1 - torch.random(5, 15) + local quality3 = quality1 - torch.random(15, 25) + + return pairwise_transform.jpeg_(src, + {quality1, quality2, quality3}, + size, offset, n, options) + end + else + error("unknown noise level: " .. 
level) + end + elseif style == "photo" then + if level == 1 then + return pairwise_transform.jpeg_(src, {torch.random(30, 75)}, + size, offset, n, + options) + elseif level == 2 then + if torch.uniform() > 0.6 then + return pairwise_transform.jpeg_(src, {torch.random(30, 60)}, + size, offset, n, options) + else + local quality1 = torch.random(40, 60) + local quality2 = quality1 - torch.random(5, 10) + return pairwise_transform.jpeg_(src, {quality1, quality2}, + size, offset, n, options) + end + else + error("unknown noise level: " .. level) end else - error("unknown noise level: " .. level) + error("unknown style: " .. style) end end -local function test_jpeg() - local loader = require './image_loader' - local src = loader.load_byte("../images/miku_CC_BY-NC.jpg") - local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, false) - image.display({image = y, legend = "y:0"}) - image.display({image = x, legend = "x:0"}) - for i = 2, 9 do - local y, x = pairwise_transform.jpeg_(pairwise_transform.random_half(src), - {i * 10}, 128, 0, {color_augment = false, random_half = true}) - image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0}) - image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0}) - --print(x:mean(), y:mean()) +function pairwise_transform.test_jpeg(src) + torch.setdefaulttensortype("torch.FloatTensor") + local options = {random_color_noise_rate = 0.5, + random_half_rate = 0.5, + random_overlay_rate = 0.5, + nr_rate = 1.0, + active_cropping_rate = 0.5, + active_cropping_tries = 10, + max_size = 256, + rgb = true + } + local image = require 'image' + local src = image.lena() + for i = 1, 9 do + local xy = pairwise_transform.jpeg(src, + "art", + torch.random(1, 2), + 128, 7, 1, options) + image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1}) + image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1}) end end +function pairwise_transform.test_scale(src) + torch.setdefaulttensortype("torch.FloatTensor") + local options = {random_color_noise_rate = 0.5, + random_half_rate = 0.5, + random_overlay_rate = 0.5, + active_cropping_rate = 0.5, + active_cropping_tries = 10, + max_size = 256, + rgb = true + } + local image = require 'image' + local src = image.lena() -local function test_scale() - local loader = require './image_loader' - local src = loader.load_byte("../images/miku_CC_BY-NC.jpg") - for i = 1, 9 do - local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true, rgb = true}) - image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1}) - image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1}) - print(y:size(), x:size()) - --print(x:mean(), y:mean()) + for i = 1, 10 do + local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options) + image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1}) + image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1}) end end -local function test_jpeg_scale() - local loader = require './image_loader' - local src = loader.load_byte("../images/miku_CC_BY-NC.jpg") - for i = 1, 9 do - local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_augment = true, random_half = true}) - image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1}) - image.display({image = x, legend = "x1:" .. 
(i * 10), min = 0, max = 1}) - print(y:size(), x:size()) - --print(x:mean(), y:mean()) - end - for i = 1, 9 do - local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_augment = true, random_half = true}) - image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1}) - image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1}) - print(y:size(), x:size()) - --print(x:mean(), y:mean()) - end -end ---test_scale() ---test_jpeg() ---test_jpeg_scale() - return pairwise_transform diff --git a/lib/portable.lua b/lib/portable.lua deleted file mode 100644 index 36a5264..0000000 --- a/lib/portable.lua +++ /dev/null @@ -1,4 +0,0 @@ -require 'torch' -require 'cutorch' -require 'nn' -require 'cunn' diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 735b8fd..6dce420 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -1,5 +1,5 @@ require 'image' -local iproc = require './iproc' +local iproc = require 'iproc' local function reconstruct_y(model, x, offset, block_size) if x:dim() == 2 then @@ -48,7 +48,8 @@ local function reconstruct_rgb(model, x, offset, block_size) end return new_x end -function model_is_rgb(model) +local reconstruct = {} +function reconstruct.is_rgb(model) if model:get(model:size() - 1).weight:size(1) == 3 then -- 3ch RGB return true @@ -57,8 +58,23 @@ function model_is_rgb(model) return false end end - -local reconstruct = {} +function reconstruct.offset_size(model) + local conv = model:findModules("nn.SpatialConvolutionMM") + if #conv > 0 then + local offset = 0 + for i = 1, #conv do + offset = offset + (conv[i].kW - 1) / 2 + end + return math.floor(offset) + else + conv = model:findModules("cudnn.SpatialConvolution") + local offset = 0 + for i = 1, #conv do + offset = offset + (conv[i].kW - 1) / 2 + end + return math.floor(offset) + end +end function reconstruct.image_y(model, x, offset, block_size) block_size = block_size or 128 local output_size = block_size - offset * 2 @@ -78,7 +94,7 @@ function reconstruct.image_y(model, x, offset, block_size) y[torch.lt(y, 0)] = 0 y[torch.gt(y, 1)] = 1 yuv[1]:copy(y) - local output = image.yuv2rgb(image.crop(yuv, + local output = image.yuv2rgb(iproc.crop(yuv, pad_w1, pad_h1, yuv:size(3) - pad_w2, yuv:size(2) - pad_h2)) output[torch.lt(output, 0)] = 0 @@ -110,7 +126,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size) y[torch.lt(y, 0)] = 0 y[torch.gt(y, 1)] = 1 yuv_jinc[1]:copy(y) - local output = image.yuv2rgb(image.crop(yuv_jinc, + local output = image.yuv2rgb(iproc.crop(yuv_jinc, pad_w1, pad_h1, yuv_jinc:size(3) - pad_w2, yuv_jinc:size(2) - pad_h2)) output[torch.lt(output, 0)] = 0 @@ -135,7 +151,7 @@ function reconstruct.image_rgb(model, x, offset, block_size) local pad_w2 = (w - offset) - x:size(3) local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2) local y = reconstruct_rgb(model, input, offset, block_size) - local output = image.crop(y, + local output = iproc.crop(y, pad_w1, pad_h1, y:size(3) - pad_w2, y:size(2) - pad_h2) collectgarbage() @@ -162,7 +178,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size) local pad_w2 = (w - offset) - x:size(3) local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2) local y = reconstruct_rgb(model, input, offset, block_size) - local output = image.crop(y, + local output = iproc.crop(y, pad_w1, pad_h1, y:size(3) - pad_w2, y:size(2) - pad_h2) output[torch.lt(output, 0)] = 0 @@ -172,18 +188,81 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size) return output end -function 
reconstruct.image(model, x, offset, block_size) - if model_is_rgb(model) then - return reconstruct.image_rgb(model, x, offset, block_size) +function reconstruct.image(model, x, block_size) + if reconstruct.is_rgb(model) then + return reconstruct.image_rgb(model, x, + reconstruct.offset_size(model), block_size) else - return reconstruct.image_y(model, x, offset, block_size) + return reconstruct.image_y(model, x, + reconstruct.offset_size(model), block_size) end end -function reconstruct.scale(model, scale, x, offset, block_size) - if model_is_rgb(model) then - return reconstruct.scale_rgb(model, scale, x, offset, block_size) +function reconstruct.scale(model, scale, x, block_size) + if reconstruct.is_rgb(model) then + return reconstruct.scale_rgb(model, scale, x, + reconstruct.offset_size(model), block_size) else - return reconstruct.scale_y(model, scale, x, offset, block_size) + return reconstruct.scale_y(model, scale, x, + reconstruct.offset_size(model), block_size) + end +end +local function tta(f, model, x, block_size) + local average = nil + local offset = reconstruct.offset_size(model) + for i = 1, 4 do + local flip_f, iflip_f + if i == 1 then + flip_f = function (a) return a end + iflip_f = function (a) return a end + elseif i == 2 then + flip_f = image.vflip + iflip_f = image.vflip + elseif i == 3 then + flip_f = image.hflip + iflip_f = image.hflip + elseif i == 4 then + flip_f = function (a) return image.hflip(image.vflip(a)) end + iflip_f = function (a) return image.vflip(image.hflip(a)) end + end + for j = 1, 2 do + local tr_f, itr_f + if j == 1 then + tr_f = function (a) return a end + itr_f = function (a) return a end + elseif j == 2 then + tr_f = function(a) return a:transpose(2, 3):contiguous() end + itr_f = function(a) return a:transpose(2, 3):contiguous() end + end + local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)), + offset, block_size))) + if not average then + average = out + else + average:add(out) + end + end + end + return average:div(8.0) +end +function reconstruct.image_tta(model, x, block_size) + if reconstruct.is_rgb(model) then + return tta(reconstruct.image_rgb, model, x, block_size) + else + return tta(reconstruct.image_y, model, x, block_size) + end +end +function reconstruct.scale_tta(model, scale, x, block_size) + if reconstruct.is_rgb(model) then + local f = function (model, x, offset, block_size) + return reconstruct.scale_rgb(model, scale, x, offset, block_size) + end + return tta(f, model, x, block_size) + + else + local f = function (model, x, offset, block_size) + return reconstruct.scale_y(model, scale, x, offset, block_size) + end + return tta(f, model, x, block_size) end end diff --git a/lib/settings.lua b/lib/settings.lua index e606868..b527528 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -1,5 +1,6 @@ require 'xlua' require 'pl' +require 'trepl' -- global settings @@ -14,22 +15,34 @@ local settings = {} local cmd = torch.CmdLine() cmd:text() -cmd:text("waifu2x") +cmd:text("waifu2x-training") cmd:text("Options:") -cmd:option("-seed", 11, 'fixed input seed') -cmd:option("-data_dir", "./data", 'data directory') -cmd:option("-test", "images/miku_small.png", 'test image file') +cmd:option("-gpu", -1, 'GPU Device ID') +cmd:option("-seed", 11, 'RNG seed') +cmd:option("-data_dir", "./data", 'path to data directory') +cmd:option("-backend", "cunn", '(cunn|cudnn)') +cmd:option("-test", "images/miku_small.png", 'path to test image') cmd:option("-model_dir", "./models", 'model directory') -cmd:option("-method", "scale", 
'(noise|scale|noise_scale)')
+cmd:option("-method", "scale", 'training method (noise|scale)')
 cmd:option("-noise_level", 1, '(1|2)')
+cmd:option("-style", "art", '(art|photo)')
 cmd:option("-color", 'rgb', '(y|rgb)')
-cmd:option("-scale", 2.0, 'scale')
+cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
+cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
+cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
+cmd:option("-scale", 2.0, 'scale factor (2)')
 cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
-cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
-cmd:option("-crop_size", 128, 'crop size')
-cmd:option("-batch_size", 2, 'mini batch size')
-cmd:option("-epoch", 200, 'epoch')
-cmd:option("-core", 2, 'cpu core')
+cmd:option("-crop_size", 46, 'crop size')
+cmd:option("-max_size", 256, 'if an image is larger than max_size, it will be randomly cropped to max_size')
+cmd:option("-batch_size", 8, 'mini batch size')
+cmd:option("-epoch", 200, 'number of total epochs to run')
+cmd:option("-thread", -1, 'number of CPU threads')
+cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
+cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
+cmd:option("-validation_crops", 80, 'number of cropping regions per image in validation')
+cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
+cmd:option("-active_cropping_tries", 10, 'active cropping tries')
+cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
 local opt = cmd:parse(arg)
 for k, v in pairs(opt) do
@@ -53,26 +66,16 @@ end
 if not (settings.scale == math.floor(settings.scale) and
 settings.scale % 2 == 0) then
    error("scale must be mod-2")
 end
-if settings.random_half == 1 then
-   settings.random_half = true
-else
-   settings.random_half = false
+if not (settings.style == "art" or
+        settings.style == "photo") then
+   error(string.format("unknown style: %s", settings.style))
+end
+
+if settings.thread > 0 then
+   torch.setnumthreads(tonumber(settings.thread))
 end
-torch.setnumthreads(settings.core)
 settings.images = string.format("%s/images.t7", settings.data_dir)
 settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
-settings.validation_ratio = 0.1
-settings.validation_crops = 40
-
-local srcnn = require './srcnn'
-if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then
-   settings.create_model = srcnn.waifu4x
-   settings.block_offset = 13
-else
-   settings.create_model = srcnn.waifu2x
-   settings.block_offset = 7
-end
-
 return settings
diff --git a/lib/srcnn.lua b/lib/srcnn.lua
index a282b93..074888d 100644
--- a/lib/srcnn.lua
+++ b/lib/srcnn.lua
@@ -1,74 +1,68 @@
-require './LeakyReLU'
+require 'w2nn'
-function nn.SpatialConvolutionMM:reset(stdv)
-   stdv = math.sqrt(2 / ( self.kW * self.kH * self.nOutputPlane))
-   self.weight:normal(0, stdv)
-   self.bias:fill(0)
-end
+-- ref: http://arxiv.org/abs/1502.01852
+-- ref: http://arxiv.org/abs/1501.00092
 local srcnn = {}
-function srcnn.waifu2x(color)
+function srcnn.channels(model)
+   return model:get(model:size() - 1).weight:size(1)
+end
+function srcnn.waifu2x_cunn(ch)
    local model = nn.Sequential()
-   local ch = nil
-   if color == "rgb" then
-      ch = 3
-   elseif color == "y" then
-      ch = 1
-   else
-      if color then
-         error("unknown color: " .. 
color) - else - error("unknown color: nil") - end - end - model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0)) - model:add(nn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0)) - model:add(nn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0)) - model:add(nn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0)) - model:add(nn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0)) - model:add(nn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0)) - model:add(nn.LeakyReLU(0.1)) + model:add(w2nn.LeakyReLU(0.1)) model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0)) model:add(nn.View(-1):setNumInputDims(3)) ---model:cuda() ---print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size()) + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) - return model, 7 + return model end - --- current 4x is worse than 2x * 2 -function srcnn.waifu4x(color) +function srcnn.waifu2x_cudnn(ch) local model = nn.Sequential() - - local ch = nil + model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0)) + model:add(w2nn.LeakyReLU(0.1)) + model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0)) + model:add(nn.View(-1):setNumInputDims(3)) + --model:cuda() + --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size()) + + return model +end +function srcnn.create(model_name, backend, color) + local ch = 3 if color == "rgb" then ch = 3 elseif color == "y" then ch = 1 else - error("unknown color: " .. 
color)
+      error("unsupported color: " .. color)
+   end
+   if backend == "cunn" then
+      return srcnn.waifu2x_cunn(ch)
+   elseif backend == "cudnn" then
+      return srcnn.waifu2x_cudnn(ch)
+   else
+      error("unsupported backend: " .. backend)
+   end
-
-   model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(32, 64, 5, 5, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(64, 128, 5, 5, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1))
-   model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
-   model:add(nn.View(-1):setNumInputDims(3))
-
-   return model, 13
 end
 return srcnn
diff --git a/lib/w2nn.lua b/lib/w2nn.lua
new file mode 100644
index 0000000..e3ac2c8
--- /dev/null
+++ b/lib/w2nn.lua
@@ -0,0 +1,26 @@
+local function load_nn()
+   require 'torch'
+   require 'nn'
+end
+local function load_cunn()
+   require 'cutorch'
+   require 'cunn'
+end
+local function load_cudnn()
+   require 'cudnn'
+   cudnn.benchmark = true
+end
+if w2nn then
+   return w2nn
+else
+   pcall(load_cunn)
+   pcall(load_cudnn)
+   w2nn = {}
+   require 'LeakyReLU'
+   require 'LeakyReLU_deprecated'
+   require 'DepthExpand2x'
+   require 'WeightedMSECriterion'
+   require 'ClippedWeightedHuberCriterion'
+   require 'cleanup_model'
+   return w2nn
+end
diff --git a/models/anime_style_art_rgb/noise1_model.json b/models/anime_style_art_rgb/noise1_model.json
index fdb1dea..1481eb7 100644
Binary files a/models/anime_style_art_rgb/noise1_model.json and b/models/anime_style_art_rgb/noise1_model.json differ
diff --git a/models/anime_style_art_rgb/noise1_model.t7 b/models/anime_style_art_rgb/noise1_model.t7
index f797451..b1ba5ae 100644
Binary files a/models/anime_style_art_rgb/noise1_model.t7 and b/models/anime_style_art_rgb/noise1_model.t7 differ
diff --git a/models/anime_style_art_rgb/noise2_model.json b/models/anime_style_art_rgb/noise2_model.json
index 40a9d88..2af44c4 100644
Binary files a/models/anime_style_art_rgb/noise2_model.json and b/models/anime_style_art_rgb/noise2_model.json differ
diff --git a/models/anime_style_art_rgb/noise2_model.t7 b/models/anime_style_art_rgb/noise2_model.t7
index acc5be2..cc3b78a 100644
Binary files a/models/anime_style_art_rgb/noise2_model.t7 and b/models/anime_style_art_rgb/noise2_model.t7 differ
diff --git a/models/anime_style_art_rgb/scale2.0x_model.json b/models/anime_style_art_rgb/scale2.0x_model.json
index 771c9bc..2ad1fbc 100644
Binary files a/models/anime_style_art_rgb/scale2.0x_model.json and b/models/anime_style_art_rgb/scale2.0x_model.json differ
diff --git a/models/anime_style_art_rgb/scale2.0x_model.t7 b/models/anime_style_art_rgb/scale2.0x_model.t7
index b368847..c419f06 100644
Binary files a/models/anime_style_art_rgb/scale2.0x_model.t7 and b/models/anime_style_art_rgb/scale2.0x_model.t7 differ
diff --git a/models/ukbench/scale2.0x_model.json b/models/ukbench/scale2.0x_model.json
index f234d80..cac09dc 100644
Binary files a/models/ukbench/scale2.0x_model.json and b/models/ukbench/scale2.0x_model.json differ
diff --git a/models/ukbench/scale2.0x_model.t7 b/models/ukbench/scale2.0x_model.t7
index 97f443e..378625b 100644
Binary files a/models/ukbench/scale2.0x_model.t7 and b/models/ukbench/scale2.0x_model.t7 differ
diff --git a/tools/benchmark.lua b/tools/benchmark.lua
new file mode 100644
index 0000000..aa07197
--- /dev/null
+++ b/tools/benchmark.lua
@@ -0,0 +1,169 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'xlua'
+require 'w2nn'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local image_loader = require 'image_loader'
+local gm = require 'graphicsmagick'
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x-benchmark")
+cmd:text("Options:")
+
+cmd:option("-dir", "./data/test", 'test image directory')
+cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
+cmd:option("-model2_dir", "", 'model2 directory (optional)')
+cmd:option("-method", "scale", '(scale|noise)')
+cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
+cmd:option("-color", "rgb", '(rgb|y)')
+cmd:option("-noise_level", 1, 'model noise level')
+cmd:option("-jpeg_quality", 75, 'jpeg quality')
+cmd:option("-jpeg_times", 1, 'jpeg compression times')
+cmd:option("-jpeg_quality_down", 5, 'amount of jpeg quality to decrease each time')
+
+local opt = cmd:parse(arg)
+torch.setdefaulttensortype('torch.FloatTensor')
+if cudnn then
+   cudnn.fastest = true
+   cudnn.benchmark = false
+end
+
+local function MSE(x1, x2)
+   return (x1 - x2):pow(2):mean()
+end
+local function YMSE(x1, x2)
+   local x1_2 = image.rgb2y(x1)
+   local x2_2 = image.rgb2y(x2)
+   return (x1_2 - x2_2):pow(2):mean()
+end
+local function PSNR(x1, x2)
+   local mse = MSE(x1, x2)
+   return 10 * math.log10(1.0 / mse)
+end
+local function YPSNR(x1, x2)
+   local mse = YMSE(x1, x2)
+   return 10 * math.log10(1.0 / mse)
+end
+
+local function transform_jpeg(x, opt)
+   for i = 1, opt.jpeg_times do
+      local jpeg = gm.Image(x, "RGB", "DHW")
+      jpeg:format("jpeg")
+      jpeg:samplingFactors({1.0, 1.0, 1.0})
+      local blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
+      jpeg:fromBlob(blob, len)
+      x = jpeg:toTensor("byte", "RGB", "DHW")
+   end
+   return x
+end
+local function transform_scale(x, opt)
+   return iproc.scale(x,
+                      x:size(3) * 0.5,
+                      x:size(2) * 0.5,
+                      opt.filter)
+end
+
+local function benchmark(opt, x, input_func, model1, model2)
+   local model1_mse = 0
+   local model2_mse = 0
+   local model1_psnr = 0
+   local model2_psnr = 0
+
+   for i = 1, #x do
+      local ground_truth = x[i]
+      local input, model1_output, model2_output
+
+      input = input_func(ground_truth, opt)
+      input = input:float():div(255)
+      ground_truth = ground_truth:float():div(255)
+
+      local t = sys.clock()
+      if input:size(3) == ground_truth:size(3) then
+         model1_output = reconstruct.image(model1, input)
+         if model2 then
+            model2_output = reconstruct.image(model2, input)
+         end
+      else
+         model1_output = reconstruct.scale(model1, 2.0, input)
+         if model2 then
+            model2_output = reconstruct.scale(model2, 2.0, input)
+         end
+      end
+      if opt.color == "y" then
+         model1_mse = model1_mse + YMSE(ground_truth, model1_output)
+         model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
+         if model2 then
+            model2_mse = model2_mse + YMSE(ground_truth, model2_output)
+            model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
+         end
+      elseif opt.color == "rgb" then
+         model1_mse = model1_mse + MSE(ground_truth, model1_output)
+         model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
+         if model2 then
+            model2_mse = model2_mse + MSE(ground_truth, model2_output)
+            model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
+         end
+      else
+         error("Unknown color: " .. opt.color)
+      end
+      if model2 then
+         io.stdout:write(
+            string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
+                          i, #x,
+                          model1_mse / i, model2_mse / i,
+                          model1_psnr / i, model2_psnr / i
+            ))
+      else
+         io.stdout:write(
+            string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
+                          i, #x,
+                          model1_mse / i, model1_psnr / i
+            ))
+      end
+      io.stdout:flush()
+   end
+   io.stdout:write("\n")
+end
+local function load_data(test_dir)
+   local test_x = {}
+   local files = dir.getfiles(test_dir, "*.*")
+   for i = 1, #files do
+      table.insert(test_x, iproc.crop_mod4(image_loader.load_byte(files[i])))
+      xlua.progress(i, #files)
+   end
+   return test_x
+end
+local function load_model(filename)
+   return torch.load(filename, "ascii")
+end
+print(opt)
+if opt.method == "scale" then
+   local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
+   local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
+   local s1, model1 = pcall(load_model, f1)
+   local s2, model2 = pcall(load_model, f2)
+   if not s1 then
+      error("Load error: " .. f1)
+   end
+   if not s2 then
+      model2 = nil
+   end
+   local test_x = load_data(opt.dir)
+   benchmark(opt, test_x, transform_scale, model1, model2)
+elseif opt.method == "noise" then
+   local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
+   local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level))
+   local s1, model1 = pcall(load_model, f1)
+   local s2, model2 = pcall(load_model, f2)
+   if not s1 then
+      error("Load error: " .. f1)
+   end
+   if not s2 then
+      model2 = nil
+   end
+   local test_x = load_data(opt.dir)
+   benchmark(opt, test_x, transform_jpeg, model1, model2)
+end
diff --git a/tools/cleanup_model.lua b/tools/cleanup_model.lua
new file mode 100644
index 0000000..934cd6e
--- /dev/null
+++ b/tools/cleanup_model.lua
@@ -0,0 +1,25 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+
+require 'w2nn'
+torch.setdefaulttensortype("torch.FloatTensor")
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("cleanup model")
+cmd:text("Options:")
+cmd:option("-model", "./model.t7", 'path to model file')
+cmd:option("-iformat", "binary", 'input format')
+cmd:option("-oformat", "binary", 'output format')
+
+local opt = cmd:parse(arg)
+local model = torch.load(opt.model, opt.iformat)
+if model then
+   w2nn.cleanup_model(model)
+   model:cuda()
+   model:evaluate()
+   torch.save(opt.model, model, opt.oformat)
+else
+   error("model not found")
+end
diff --git a/tools/cudnn2cunn.lua b/tools/cudnn2cunn.lua
new file mode 100644
index 0000000..5c8030e
--- /dev/null
+++ b/tools/cudnn2cunn.lua
@@ -0,0 +1,43 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'os'
+require 'w2nn'
+local srcnn = require 'srcnn'
+
+local function cudnn2cunn(cudnn_model)
+   local cunn_model = srcnn.waifu2x_cunn(srcnn.channels(cudnn_model))
+   local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
+   local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
+
+   assert(#weight_from == #weight_to)
+
+   for i = 1, #weight_from do
+      local from = weight_from[i]
+      local to = weight_to[i]
+
+      to.weight:copy(from.weight)
+      to.bias:copy(from.bias)
+   end
+   cunn_model:cuda()
+   cunn_model:evaluate()
+   return cunn_model
+end
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x cudnn model to cunn model converter")
+cmd:text("Options:")
+cmd:option("-i", "", 'Specify the input cudnn model')
+cmd:option("-o", "", 'Specify the output cunn model')
+cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
+cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
+
+local opt = cmd:parse(arg)
+if not path.isfile(opt.i) then
+   cmd:help()
+   os.exit(-1)
+end
+local cudnn_model = torch.load(opt.i, opt.iformat)
+local cunn_model = cudnn2cunn(cudnn_model)
+torch.save(opt.o, cunn_model, opt.oformat)
diff --git a/tools/cunn2cudnn.lua b/tools/cunn2cudnn.lua
new file mode 100644
index 0000000..d4c198d
--- /dev/null
+++ b/tools/cunn2cudnn.lua
@@ -0,0 +1,43 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
+require 'os'
+require 'w2nn'
+local srcnn = require 'srcnn'
+
+local function cunn2cudnn(cunn_model)
+   local cudnn_model = srcnn.waifu2x_cudnn(srcnn.channels(cunn_model))
+   local weight_from = cunn_model:findModules("nn.SpatialConvolutionMM")
+   local weight_to = cudnn_model:findModules("cudnn.SpatialConvolution")
+
+   assert(#weight_from == #weight_to)
+
+   for i = 1, #weight_from do
+      local from = weight_from[i]
+      local to = weight_to[i]
+
+      to.weight:copy(from.weight)
+      to.bias:copy(from.bias)
+   end
+   cudnn_model:cuda()
+   cudnn_model:evaluate()
+   return cudnn_model
+end
+
+local cmd = torch.CmdLine()
+cmd:text()
+cmd:text("waifu2x cunn model to cudnn model converter")
+cmd:text("Options:")
+cmd:option("-i", "", 'Specify the input cunn model')
+cmd:option("-o", "", 'Specify the output cudnn model')
+cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
+cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
+
+local opt = cmd:parse(arg)
+if not path.isfile(opt.i) then
+   cmd:help()
+   os.exit(-1)
+end
+local cunn_model = torch.load(opt.i, opt.iformat)
+local cudnn_model = cunn2cudnn(cunn_model)
+torch.save(opt.o, cudnn_model, opt.oformat)
diff --git a/tools/export_model.lua b/tools/export_model.lua
new file mode 100644
index 0000000..8ba10e5
--- /dev/null
+++ b/tools/export_model.lua
@@ -0,0 +1,54 @@
+-- adapted from https://github.com/marcan/cl-waifu2x
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. 
package.path +require 'w2nn' +local cjson = require "cjson" + +function export(model, output) + local jmodules = {} + local modules = model:findModules("nn.SpatialConvolutionMM") + if #modules == 0 then + -- cudnn model + modules = model:findModules("cudnn.SpatialConvolution") + end + for i = 1, #modules, 1 do + local module = modules[i] + local jmod = { + kW = module.kW, + kH = module.kH, + nInputPlane = module.nInputPlane, + nOutputPlane = module.nOutputPlane, + bias = torch.totable(module.bias:float()), + weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH)) + } + table.insert(jmodules, jmod) + end + jmodules[1].color = "RGB" + jmodules[1].gamma = 0 + jmodules[#jmodules].color = "RGB" + jmodules[#jmodules].gamma = 0 + + local fp = io.open(output, "w") + if not fp then + error("IO Error: " .. output) + end + fp:write(cjson.encode(jmodules)) + fp:close() +end + +local cmd = torch.CmdLine() +cmd:text() +cmd:text("waifu2x export model") +cmd:text("Options:") +cmd:option("-i", "input.t7", 'Specify the input torch model') +cmd:option("-o", "output.json", 'Specify the output json file') +cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)') + +local opt = cmd:parse(arg) +if not path.isfile(opt.i) then + cmd:help() + os.exit(-1) +end +local model = torch.load(opt.i, opt.iformat) +export(model, opt.o) diff --git a/train.lua b/train.lua index 5e226f2..222c069 100644 --- a/train.lua +++ b/train.lua @@ -1,21 +1,25 @@ -require './lib/portable' +require 'pl' +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path require 'optim' require 'xlua' -require 'pl' -local settings = require './lib/settings' -local minibatch_adam = require './lib/minibatch_adam' -local iproc = require './lib/iproc' -local reconstruct = require './lib/reconstruct' -local pairwise_transform = require './lib/pairwise_transform' -local image_loader = require './lib/image_loader' +require 'w2nn' +local settings = require 'settings' +local srcnn = require 'srcnn' +local minibatch_adam = require 'minibatch_adam' +local iproc = require 'iproc' +local reconstruct = require 'reconstruct' +local compression = require 'compression' +local pairwise_transform = require 'pairwise_transform' +local image_loader = require 'image_loader' local function save_test_scale(model, rgb, file) - local up = reconstruct.scale(model, settings.scale, rgb, settings.block_offset) + local up = reconstruct.scale(model, settings.scale, rgb) image.save(file, up) end local function save_test_jpeg(model, rgb, file) - local im, count = reconstruct.image(model, rgb, settings.block_offset) + local im, count = reconstruct.image(model, rgb) image.save(file, im) end local function split_data(x, test_size) @@ -31,14 +35,19 @@ local function split_data(x, test_size) end return train_x, valid_x end -local function make_validation_set(x, transformer, n) +local function make_validation_set(x, transformer, n, batch_size) n = n or 4 local data = {} for i = 1, #x do - for k = 1, n do - local x, y = transformer(x[i], true) - table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)), - y = y:reshape(1, y:size(1), y:size(2), y:size(3))}) + for k = 1, math.max(n / batch_size, 1) do + local xy = transformer(x[i], true, batch_size) + local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3)) + local ty = torch.Tensor(batch_size, 
xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3)) + for j = 1, #xy do + tx[j]:copy(xy[j][1]) + ty[j]:copy(xy[j][2]) + end + table.insert(data, {x = tx, y = ty}) end xlua.progress(i, #x) collectgarbage() @@ -50,24 +59,92 @@ local function validate(model, criterion, data) for i = 1, #data do local z = model:forward(data[i].x:cuda()) loss = loss + criterion:forward(z, data[i].y:cuda()) - xlua.progress(i, #data) - if i % 10 == 0 then + if i % 100 == 0 then + xlua.progress(i, #data) collectgarbage() end end + xlua.progress(#data, #data) return loss / #data end +local function create_criterion(model) + if reconstruct.is_rgb(model) then + local offset = reconstruct.offset_size(model) + local output_w = settings.crop_size - offset * 2 + local weight = torch.Tensor(3, output_w * output_w) + weight[1]:fill(0.29891 * 3) -- R + weight[2]:fill(0.58661 * 3) -- G + weight[3]:fill(0.11448 * 3) -- B + return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda() + else + return nn.MSECriterion():cuda() + end +end +local function transformer(x, is_validation, n, offset) + x = compression.decompress(x) + n = n or settings.batch_size; + if is_validation == nil then is_validation = false end + local random_color_noise_rate = nil + local random_overlay_rate = nil + local active_cropping_rate = nil + local active_cropping_tries = nil + if is_validation then + active_cropping_rate = 0 + active_cropping_tries = 0 + random_color_noise_rate = 0.0 + random_overlay_rate = 0.0 + else + active_cropping_rate = settings.active_cropping_rate + active_cropping_tries = settings.active_cropping_tries + random_color_noise_rate = settings.random_color_noise_rate + random_overlay_rate = settings.random_overlay_rate + end + + if settings.method == "scale" then + return pairwise_transform.scale(x, + settings.scale, + settings.crop_size, offset, + n, + { + random_half_rate = settings.random_half_rate, + random_color_noise_rate = random_color_noise_rate, + random_overlay_rate = random_overlay_rate, + max_size = settings.max_size, + active_cropping_rate = active_cropping_rate, + active_cropping_tries = active_cropping_tries, + rgb = (settings.color == "rgb") + }) + elseif settings.method == "noise" then + return pairwise_transform.jpeg(x, + settings.style, + settings.noise_level, + settings.crop_size, offset, + n, + { + random_half_rate = settings.random_half_rate, + random_color_noise_rate = random_color_noise_rate, + random_overlay_rate = random_overlay_rate, + max_size = settings.max_size, + jpeg_sampling_factors = settings.jpeg_sampling_factors, + active_cropping_rate = active_cropping_rate, + active_cropping_tries = active_cropping_tries, + nr_rate = settings.nr_rate, + rgb = (settings.color == "rgb") + }) + end +end + local function train() - local model, offset = settings.create_model(settings.color) - assert(offset == settings.block_offset) - local criterion = nn.MSECriterion():cuda() + local model = srcnn.create(settings.method, settings.backend, settings.color) + local offset = reconstruct.offset_size(model) + local pairwise_func = function(x, is_validation, n) + return transformer(x, is_validation, n, offset) + end + local criterion = create_criterion(model) local x = torch.load(settings.images) local lrd_count = 0 - local train_x, valid_x = split_data(x, - math.floor(settings.validation_ratio * #x), - settings.validation_crops) - local test = image_loader.load_float(settings.test) + local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x)) local adam_config = { learningRate = 
settings.learning_rate, xBatchSize = settings.batch_size, @@ -78,38 +155,11 @@ local function train() elseif settings.color == "rgb" then ch = 3 end - local transformer = function(x, is_validation) - if is_validation == nil then is_validation = false end - if settings.method == "scale" then - return pairwise_transform.scale(x, - settings.scale, - settings.crop_size, offset, - { color_augment = not is_validation, - random_half = settings.random_half, - rgb = (settings.color == "rgb") - }) - elseif settings.method == "noise" then - return pairwise_transform.jpeg(x, - settings.noise_level, - settings.crop_size, offset, - { color_augment = not is_validation, - random_half = settings.random_half, - rgb = (settings.color == "rgb") - }) - elseif settings.method == "noise_scale" then - return pairwise_transform.jpeg_scale(x, - settings.scale, - settings.noise_level, - settings.crop_size, offset, - { color_augment = not is_validation, - random_half = settings.random_half, - rgb = (settings.color == "rgb") - }) - end - end local best_score = 100000.0 print("# make validation-set") - local valid_xy = make_validation_set(valid_x, transformer, 20) + local valid_xy = make_validation_set(valid_x, pairwise_func, + settings.validation_crops, + settings.batch_size) valid_x = nil collectgarbage() @@ -119,7 +169,7 @@ local function train() model:training() print("# " .. epoch) print(minibatch_adam(model, criterion, train_x, adam_config, - transformer, + pairwise_func, {ch, settings.crop_size, settings.crop_size}, {ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2} )) @@ -127,6 +177,7 @@ local function train() print("# validation") local score = validate(model, criterion, valid_xy) if score < best_score then + local test_image = image_loader.load_float(settings.test) -- reload lrd_count = 0 best_score = score print("* update best model") @@ -134,22 +185,17 @@ local function train() if settings.method == "noise" then local log = path.join(settings.model_dir, ("noise%d_best.png"):format(settings.noise_level)) - save_test_jpeg(model, test, log) + save_test_jpeg(model, test_image, log) elseif settings.method == "scale" then local log = path.join(settings.model_dir, ("scale%.1f_best.png"):format(settings.scale)) - save_test_scale(model, test, log) - elseif settings.method == "noise_scale" then - local log = path.join(settings.model_dir, - ("noise%d_scale%.1f_best.png"):format(settings.noise_level, - settings.scale)) - save_test_scale(model, test, log) + save_test_scale(model, test_image, log) end else lrd_count = lrd_count + 1 if lrd_count > 5 then lrd_count = 0 - adam_config.learningRate = adam_config.learningRate * 0.8 + adam_config.learningRate = adam_config.learningRate * 0.9 print("* learning rate decay: " .. 
adam_config.learningRate) end end @@ -157,6 +203,9 @@ local function train() collectgarbage() end end +if settings.gpu > 0 then + cutorch.setDevice(settings.gpu) +end torch.manualSeed(settings.seed) cutorch.manualSeed(settings.seed) print(settings) diff --git a/train.sh b/train.sh index 2fecb89..79804ea 100755 --- a/train.sh +++ b/train.sh @@ -1,10 +1,12 @@ #!/bin/sh -th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii +th convert_data.lua -th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii +th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4 +th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii -th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png -th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii +th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4 +th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii + +th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4 +th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii diff --git a/train_ukbench.sh b/train_ukbench.sh new file mode 100755 index 0000000..6bdacf3 --- /dev/null +++ b/train_ukbench.sh @@ -0,0 +1,9 @@ +#!/bin/sh + +th convert_data.lua -data_dir ./data/ukbench + +#th train.lua -style photo -method noise -noise_level 2 -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.png -nr_rate 0.9 -jpeg_sampling_factors 420 # -thread 4 -backend cudnn +#th tools/cleanup_model.lua -model models/ukbench/noise2_model.t7 -oformat ascii + +th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg # -thread 4 -backend cudnn +th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii diff --git a/waifu2x.lua b/waifu2x.lua index d033336..a1faf7f 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -1,12 +1,11 @@ -require './lib/portable' -require 'sys' require 'pl' -require './lib/LeakyReLU' - -local iproc = require './lib/iproc' -local reconstruct = require './lib/reconstruct' -local image_loader = require './lib/image_loader' -local BLOCK_OFFSET = 7 +local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() +package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. 
package.path +require 'sys' +require 'w2nn' +local iproc = require 'iproc' +local reconstruct = require 'reconstruct' +local image_loader = require 'image_loader' torch.setdefaulttensortype('torch.FloatTensor') @@ -14,43 +13,109 @@ local function convert_image(opt) local x, alpha = image_loader.load_float(opt.i) local new_x = nil local t = sys.clock() + local scale_f, image_f + if opt.tta == 1 then + scale_f = reconstruct.scale_tta + image_f = reconstruct.image_tta + else + scale_f = reconstruct.scale + image_f = reconstruct.image + end if opt.o == "(auto)" then local name = path.basename(opt.i) local e = path.extension(name) local base = name:sub(0, name:len() - e:len()) - opt.o = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m)) + opt.o = path.join(path.dirname(opt.i), string.format("%s_%s.png", base, opt.m)) end if opt.m == "noise" then - local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii") - model:evaluate() - new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size) + local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)) + local model = torch.load(model_path, "ascii") + if not model then + error("Load Error: " .. model_path) + end + new_x = image_f(model, x, opt.crop_size) elseif opt.m == "scale" then - local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii") - model:evaluate() - new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) + local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) + local model = torch.load(model_path, "ascii") + if not model then + error("Load Error: " .. model_path) + end + new_x = scale_f(model, opt.scale, x, opt.crop_size) elseif opt.m == "noise_scale" then - local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii") - local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii") - noise_model:evaluate() - scale_model:evaluate() - x = reconstruct.image(noise_model, x, BLOCK_OFFSET) - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) + local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)) + local noise_model = torch.load(noise_model_path, "ascii") + local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) + local scale_model = torch.load(scale_model_path, "ascii") + + if not noise_model then + error("Load Error: " .. noise_model_path) + end + if not scale_model then + error("Load Error: " .. scale_model_path) + end + x = image_f(noise_model, x, opt.crop_size) + new_x = scale_f(scale_model, opt.scale, x, opt.crop_size) else error("undefined method:" .. opt.method) end - image_loader.save_png(opt.o, new_x, alpha) + if opt.white_noise == 1 then + new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0}) + end + image_loader.save_png(opt.o, new_x, alpha, opt.depth) print(opt.o .. ": " .. (sys.clock() - t) .. 
" sec") end local function convert_frames(opt) - local noise1_model = torch.load(path.join(opt.model_dir, "noise1_model.t7"), "ascii") - local noise2_model = torch.load(path.join(opt.model_dir, "noise2_model.t7"), "ascii") - local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii") - - noise1_model:evaluate() - noise2_model:evaluate() - scale_model:evaluate() - + local model_path, noise1_model, noise2_model, scale_model + local scale_f, image_f + if opt.tta == 1 then + scale_f = reconstruct.scale_tta + image_f = reconstruct.image_tta + else + scale_f = reconstruct.scale + image_f = reconstruct.image + end + if opt.m == "scale" then + model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) + scale_model = torch.load(model_path, "ascii") + if not scale_model then + error("Load Error: " .. model_path) + end + elseif opt.m == "noise" and opt.noise_level == 1 then + model_path = path.join(opt.model_dir, "noise1_model.t7") + noise1_model = torch.load(model_path, "ascii") + if not noise1_model then + error("Load Error: " .. model_path) + end + elseif opt.m == "noise" and opt.noise_level == 2 then + model_path = path.join(opt.model_dir, "noise2_model.t7") + noise2_model = torch.load(model_path, "ascii") + if not noise2_model then + error("Load Error: " .. model_path) + end + elseif opt.m == "noise_scale" then + model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)) + scale_model = torch.load(model_path, "ascii") + if not scale_model then + error("Load Error: " .. model_path) + end + if opt.noise_level == 1 then + model_path = path.join(opt.model_dir, "noise1_model.t7") + noise1_model = torch.load(model_path, "ascii") + if not noise1_model then + error("Load Error: " .. model_path) + end + elseif opt.noise_level == 2 then + model_path = path.join(opt.model_dir, "noise2_model.t7") + noise2_model = torch.load(model_path, "ascii") + if not noise2_model then + error("Load Error: " .. model_path) + end + end + end local fp = io.open(opt.l) + if not fp then + error("Open Error: " .. opt.l) + end local count = 0 local lines = {} for line in fp:lines() do @@ -62,20 +127,24 @@ local function convert_frames(opt) local x, alpha = image_loader.load_float(lines[i]) local new_x = nil if opt.m == "noise" and opt.noise_level == 1 then - new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size) + new_x = image_f(noise1_model, x, opt.crop_size) elseif opt.m == "noise" and opt.noise_level == 2 then - new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) + new_x = image_func(noise2_model, x, opt.crop_size) elseif opt.m == "scale" then - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) + new_x = scale_f(scale_model, opt.scale, x, opt.crop_size) elseif opt.m == "noise_scale" and opt.noise_level == 1 then - x = reconstruct.image(noise1_model, x, BLOCK_OFFSET) - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) + x = image_f(noise1_model, x, opt.crop_size) + new_x = scale_f(scale_model, opt.scale, x, opt.crop_size) elseif opt.m == "noise_scale" and opt.noise_level == 2 then - x = reconstruct.image(noise2_model, x, BLOCK_OFFSET) - new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size) + x = image_f(noise2_model, x, opt.crop_size) + new_x = scale_f(scale_model, opt.scale, x, opt.crop_size) else error("undefined method:" .. 
      end
+      if opt.white_noise == 1 then
+         new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
+      end
+
      local output = nil
      if opt.o == "(auto)" then
         local name = path.basename(lines[i])
@@ -85,7 +154,7 @@
      else
         output = string.format(opt.o, i)
      end
-      image_loader.save_png(output, new_x, alpha)
+      image_loader.save_png(output, new_x, alpha, opt.depth)
      xlua.progress(i, #lines)
      if i % 10 == 0 then
         collectgarbage()
@@ -101,17 +170,30 @@ local function waifu2x()
   cmd:text()
   cmd:text("waifu2x")
   cmd:text("Options:")
-   cmd:option("-i", "images/miku_small.png", 'path of the input image')
-   cmd:option("-l", "", 'path of the image-list')
+   cmd:option("-i", "images/miku_small.png", 'path to input image')
+   cmd:option("-l", "", 'path to image-list.txt')
   cmd:option("-scale", 2, 'scale factor')
-   cmd:option("-o", "(auto)", 'path of the output file')
-   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
+   cmd:option("-o", "(auto)", 'path to output file')
+   cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
+   cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
   cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
   cmd:option("-noise_level", 1, '(1|2)')
   cmd:option("-crop_size", 128, 'patch size per process')
   cmd:option("-resume", 0, "skip existing files (0|1)")
+   cmd:option("-thread", -1, "number of CPU threads")
+   cmd:option("-tta", 0, '8x slower and slightly higher quality (0|1)')
+   cmd:option("-white_noise", 0, 'add white noise to the output image (0|1)')
+   cmd:option("-white_noise_std", 0.0055, 'standard deviation of white noise')
   local opt = cmd:parse(arg)
+   if opt.thread > 0 then
+      torch.setnumthreads(opt.thread)
+   end
+   if cudnn then
+      cudnn.fastest = true
+      cudnn.benchmark = false
+   end
+
  if string.len(opt.l) == 0 then
     convert_image(opt)
  else
diff --git a/web.lua b/web.lua
index 63c7a9f..a81ed8e 100644
--- a/web.lua
+++ b/web.lua
@@ -1,11 +1,21 @@
+require 'pl'
+local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
+local ROOT = path.dirname(__FILE__)
+package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
 _G.TURBO_SSL = true
-local turbo = require 'turbo'
+
+require 'w2nn'
 local uuid = require 'uuid'
 local ffi = require 'ffi'
 local md5 = require 'md5'
-require 'pl'
-require './lib/portable'
-require './lib/LeakyReLU'
+local iproc = require 'iproc'
+local reconstruct = require 'reconstruct'
+local image_loader = require 'image_loader'
+
+-- Note: turbo and xlua have different implementations of string:split(),
+-- so the two definitions conflict.
+-- This script uses turbo's string:split().
+local turbo = require 'turbo'
 
 local cmd = torch.CmdLine()
 cmd:text()
@@ -13,24 +23,27 @@ cmd:text("waifu2x-api")
 cmd:text("Options:")
 cmd:option("-port", 8812, 'listen port')
 cmd:option("-gpu", 1, 'Device ID')
-cmd:option("-core", 2, 'number of CPU cores')
+cmd:option("-thread", -1, 'number of CPU threads')
 local opt = cmd:parse(arg)
 cutorch.setDevice(opt.gpu)
 torch.setdefaulttensortype('torch.FloatTensor')
-torch.setnumthreads(opt.core)
-
-local iproc = require './lib/iproc'
-local reconstruct = require './lib/reconstruct'
-local image_loader = require './lib/image_loader'
-
-local MODEL_DIR = "./models/anime_style_art_rgb"
-
-local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
-local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
-local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii")
-
-local USE_CACHE = true
-local CACHE_DIR = "./cache"
+if opt.thread > 0 then
+   torch.setnumthreads(opt.thread)
+end
+if cudnn then
+   cudnn.fastest = true
+   cudnn.benchmark = false
+end
+local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
+local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
+local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
+local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
+local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
+--local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
+--local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
+--local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
+local CLEANUP_MODEL = false -- set this to true if you are using a low-memory GPU
+local CACHE_DIR = path.join(ROOT, "cache")
 local MAX_NOISE_IMAGE = 2560 * 2560
 local MAX_SCALE_IMAGE = 1280 * 1280
 local CURL_OPTIONS = {
@@ -40,7 +53,6 @@
    max_redirects = 2
 }
 local CURL_MAX_SIZE = 2 * 1024 * 1024
-local BLOCK_OFFSET = 7 -- see srcnn.lua
 local function valid_size(x, scale)
    if scale == 0 then
@@ -50,20 +62,16 @@
    end
 end
-local function get_image(req)
-   local file = req:get_argument("file", "")
-   local url = req:get_argument("url", "")
-   local blob = nil
-   local img = nil
-   local alpha = nil
-   if file and file:len() > 0 then
-      blob = file
-      img, alpha = image_loader.decode_float(blob)
-   elseif url and url:len() > 0 then
+local function cache_url(url)
+   local hash = md5.sumhexa(url)
+   local cache_file = path.join(CACHE_DIR, "url_" .. 
hash) + if path.exists(cache_file) then + return image_loader.load_float(cache_file) + else local res = coroutine.yield( turbo.async.HTTPClient({verify_ca=false}, - nil, - CURL_MAX_SIZE):fetch(url, CURL_OPTIONS) + nil, + CURL_MAX_SIZE):fetch(url, CURL_OPTIONS) ) if res.code == 200 then local content_type = res.headers:get("Content-Type", true) @@ -71,33 +79,64 @@ local function get_image(req) content_type = content_type[1] end if content_type and content_type:find("image") then - blob = res.body - img, alpha = image_loader.decode_float(blob) + local fp = io.open(cache_file, "wb") + local blob = res.body + fp:write(blob) + fp:close() + return image_loader.decode_float(blob) end end end - return img, blob, alpha + return nil, nil, nil end - -local function apply_denoise1(x) - return reconstruct.image(noise1_model, x, BLOCK_OFFSET) +local function get_image(req) + local file = req:get_argument("file", "") + local url = req:get_argument("url", "") + if file and file:len() > 0 then + return image_loader.decode_float(file) + elseif url and url:len() > 0 then + return cache_url(url) + end + return nil, nil, nil end -local function apply_denoise2(x) - return reconstruct.image(noise2_model, x, BLOCK_OFFSET) +local function cleanup_model(model) + if CLEANUP_MODEL then + w2nn.cleanup_model(model) -- release GPU memory + end end -local function apply_scale2x(x) - return reconstruct.scale(scale20_model, 2.0, x, BLOCK_OFFSET) -end -local function cache_do(cache, x, func) - if path.exists(cache) then - return image.load(cache) +local function convert(x, options) + local cache_file = path.join(CACHE_DIR, options.prefix .. ".png") + if path.exists(cache_file) then + return image.load(cache_file) else - x = func(x) - image.save(cache, x) + if options.style == "art" then + if options.method == "scale" then + x = reconstruct.scale(art_scale2_model, 2.0, x) + cleanup_model(art_scale2_model) + elseif options.method == "noise1" then + x = reconstruct.image(art_noise1_model, x) + cleanup_model(art_noise1_model) + else -- options.method == "noise2" + x = reconstruct.image(art_noise2_model, x) + cleanup_model(art_noise2_model) + end + else --[[photo + if options.method == "scale" then + x = reconstruct.scale(photo_scale2_model, 2.0, x) + cleanup_model(photo_scale2_model) + elseif options.method == "noise1" then + x = reconstruct.image(photo_noise1_model, x) + cleanup_model(photo_noise1_model) + elseif options.method == "noise2" then + x = reconstruct.image(photo_noise2_model, x) + cleanup_model(photo_noise2_model) + end + --]] + end + image.save(cache_file, x) return x end end - local function client_disconnected(handler) return not(handler.request and handler.request.connection and @@ -112,63 +151,51 @@ function APIHandler:post() self:write("client disconnected") return end - local x, src, alpha = get_image(self) + local x, alpha, blob = get_image(self) local scale = tonumber(self:get_argument("scale", "0")) local noise = tonumber(self:get_argument("noise", "0")) + local white_noise = tonumber(self:get_argument("white_noise", "0")) + local style = self:get_argument("style", "art") + if style ~= "art" then + style = "photo" -- style must be art or photo + end if x and valid_size(x, scale) then - if USE_CACHE and (noise ~= 0 or scale ~= 0) then - local hash = md5.sumhexa(src) - local cache_noise1 = path.join(CACHE_DIR, hash .. "_noise1.png") - local cache_noise2 = path.join(CACHE_DIR, hash .. "_noise2.png") - local cache_scale = path.join(CACHE_DIR, hash .. 
"_scale.png") - local cache_noise1_scale = path.join(CACHE_DIR, hash .. "_noise1_scale.png") - local cache_noise2_scale = path.join(CACHE_DIR, hash .. "_noise2_scale.png") - + if (noise ~= 0 or scale ~= 0) then + local hash = md5.sumhexa(blob) if noise == 1 then - x = cache_do(cache_noise1, x, apply_denoise1) + x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash}) elseif noise == 2 then - x = cache_do(cache_noise2, x, apply_denoise2) + x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash}) end if scale == 1 or scale == 2 then if noise == 1 then - x = cache_do(cache_noise1_scale, x, apply_scale2x) + x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash}) elseif noise == 2 then - x = cache_do(cache_noise2_scale, x, apply_scale2x) + x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash}) else - x = cache_do(cache_scale, x, apply_scale2x) + x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash}) end if scale == 1 then - x = iproc.scale(x, - math.floor(x:size(3) * (1.6 / 2.0) + 0.5), - math.floor(x:size(2) * (1.6 / 2.0) + 0.5), - "Jinc") + x = iproc.scale_with_gamma22(x, + math.floor(x:size(3) * (1.6 / 2.0) + 0.5), + math.floor(x:size(2) * (1.6 / 2.0) + 0.5), + "Jinc") end end - elseif noise ~= 0 or scale ~= 0 then - if noise == 1 then - x = apply_denoise1(x) - elseif noise == 2 then - x = apply_denoise2(x) - end - if scale == 1 then - local x16 = {math.floor(x:size(3) * 1.6 + 0.5), math.floor(x:size(2) * 1.6 + 0.5)} - x = apply_scale2x(x) - x = iproc.scale(x, x16[1], x16[2], "Jinc") - elseif scale == 2 then - x = apply_scale2x(x) + if white_noise == 1 then + x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0}) end end local name = uuid() .. ".png" - local blob, len = image_loader.encode_png(x, alpha) - + local blob = image_loader.encode_png(x, alpha) self:set_header("Content-Disposition", string.format('filename="%s"', name)) self:set_header("Content-Type", "image/png") - self:set_header("Content-Length", string.format("%d", len)) - self:write(ffi.string(blob, len)) + self:set_header("Content-Length", string.format("%d", #blob)) + self:write(blob) else if not x then self:set_status(400) - self:write("ERROR: unsupported image format.") + self:write("ERROR: An error occurred. 
(unsupported image format/connection timeout/file is too large)") else self:set_status(400) self:write("ERROR: image size exceeds maximum allowable size.") @@ -177,9 +204,9 @@ function APIHandler:post() collectgarbage() end local FormHandler = class("FormHandler", turbo.web.RequestHandler) -local index_ja = file.read("./assets/index.ja.html") -local index_ru = file.read("./assets/index.ru.html") -local index_en = file.read("./assets/index.html") +local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html")) +local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html")) +local index_en = file.read(path.join(ROOT, "assets", "index.html")) function FormHandler:get() local lang = self.request.headers:get("Accept-Language") if lang then @@ -209,9 +236,11 @@ turbo.log.categories = { local app = turbo.web.Application:new( { {"^/$", FormHandler}, - {"^/index.html", turbo.web.StaticFileHandler, path.join("./assets", "index.html")}, - {"^/index.ja.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ja.html")}, - {"^/index.ru.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ru.html")}, + {"^/style.css", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "style.css")}, + {"^/ui.js", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "ui.js")}, + {"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")}, + {"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")}, + {"^/index.ru.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ru.html")}, {"^/api$", APIHandler}, } )
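The new tta helper in lib/reconstruct.lua averages the network output over the 8 flip/transpose variants of the input (the symmetries of the square), which is why the -tta flag is documented as 8x slower. A minimal usage sketch, assuming a CUDA-capable machine, lib/ on package.path, and the model/input paths below (both illustrative):

require 'w2nn'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'

local model = torch.load("models/anime_style_art_rgb/scale2.0x_model.t7", "ascii")
local x = image_loader.load_float("input.png") -- hypothetical input image
-- block_size = 128; the offset is now derived from the model, not passed in
local y = reconstruct.scale_tta(model, 2.0, x, 128)
image.save("output.png", y)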
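Relatedly, srcnn.create builds the network per backend, and reconstruct.offset_size derives the reconstruction offset from the convolution kernels in place of the removed BLOCK_OFFSET constant. A sketch of how the two fit together, under the same install assumptions:

require 'w2nn'
local srcnn = require 'srcnn'
local reconstruct = require 'reconstruct'

-- 7-layer RGB model on the cunn backend; the first argument is the method name
local model = srcnn.create("scale", "cunn", "rgb")
-- each unpadded 3x3 convolution trims (kW - 1) / 2 = 1 pixel per side,
-- so seven layers give an offset of 7
print(reconstruct.offset_size(model)) -- 7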
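tools/benchmark.lua reports PSNR as 10 * log10(1 / MSE); the peak signal value is 1.0 because images are divided by 255 before comparison. A worked example of the formula (illustration only, not part of the tool):

-- a uniform per-pixel error of 0.01 on a [0, 1]-scaled image:
local mse = 0.01 * 0.01            -- 1e-4
print(10 * math.log10(1.0 / mse))  -- 40 (dB)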
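create_criterion in train.lua weights the RGB reconstruction error by the BT.601 luma coefficients, scaled by 3 so the mean weight across the three channels stays 1.0 (0.29891 + 0.58661 + 0.11448 = 1.0). The same weighting in isolation, with a hypothetical 2x2 output patch for brevity:

require 'torch'
local output_w = 2 -- training uses settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
print(weight:mean()) -- 1.0 (up to float rounding)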
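Finally, web.lua now caches downloaded images under cache/url_<md5-of-URL>, so repeated requests for the same URL skip the HTTP fetch; appendix/purge_cache.lua expires these entries alongside the PNG cache. The key scheme in isolation (the URL is illustrative):

local md5 = require 'md5'
local url = "https://example.com/miku.png" -- hypothetical URL
print("cache/url_" .. md5.sumhexa(url))    -- cache file path for this URL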