
Merge pull request #58 from nagadomi/dev

Sync from development branch
nagadomi committed 2015-11-15 12:26:49 +09:00
commit 68593d9c51
62 changed files with 1864 additions and 875 deletions

.gitattributes (new file)

@ -0,0 +1,2 @@
models/*/*.json binary
*.t7 binary

.gitignore

@ -1,4 +1,15 @@
*~
work/
cache/*.png
models/*.png
cache/url_*
data/
!data/.gitkeep
models/
!models/anime_style_art
!models/anime_style_art_rgb
!models/ukbench
models/*/*.png
waifu2x.log


@ -19,16 +19,11 @@ waifu2x is inspired by SRCNN [1]. 2D character picture (HatsuneMiku) is licensed
## Public AMI
```
AMI ID: ami-0be01e4f
AMI NAME: waifu2x-server
Instance Type: g2.2xlarge
Region: US West (N.California)
OS: Ubuntu 14.04
User: ubuntu
Created at: 2015-08-12
TODO
```
## Third Party Software
[Third-Party](https://github.com/nagadomi/waifu2x/wiki/Third-Party)
## Dependencies
@ -37,10 +32,12 @@ Created at: 2015-08-12
- NVIDIA GPU
### Platform
- [Torch7](http://torch.ch/)
- [NVIDIA CUDA](https://developer.nvidia.com/cuda-toolkit)
### luarocks packages (excludes torch7's default packages)
- lua-csnappy
- md5
- uuid
- [turbo](https://github.com/kernelsauce/turbo)
@ -57,34 +54,44 @@ See: [NVIDIA CUDA Getting Started Guide for Linux](http://docs.nvidia.com/cuda/c
Download [CUDA](http://developer.nvidia.com/cuda-downloads)
```
sudo dpkg -i cuda-repo-ubuntu1404_7.0-28_amd64.deb
sudo dpkg -i cuda-repo-ubuntu1404_7.5-18_amd64.deb
sudo apt-get update
sudo apt-get install cuda
```
#### Install Package
```
sudo apt-get install libsnappy-dev
```
#### Install Torch7
See: [Getting started with Torch](http://torch.ch/docs/getting-started.html)
And install luarocks packages.
```
luarocks install graphicsmagick # upgrade
luarocks install lua-csnappy
luarocks install md5
luarocks install uuid
PREFIX=$HOME/torch/install luarocks install turbo # if you need to use web application
```
#### Getting waifu2x
```
git clone --depth 1 https://github.com/nagadomi/waifu2x.git
```
#### Validation
Test the waifu2x command line tool.
Testing the waifu2x command line tool.
```
th waifu2x.lua
```
### Setting Up the Web Application Environment (if you need it)
#### Install packages
```
luarocks install md5
luarocks install uuid
PREFIX=$HOME/torch/install luarocks install turbo
```
## Web Application
Run.
```
th web.lua
```
@ -114,11 +121,11 @@ th waifu2x.lua -m noise_scale -noise_level 1 -i input_image.png -o output_image.
th waifu2x.lua -m noise_scale -noise_level 2 -i input_image.png -o output_image.png
```
See also `images/gen.sh`.
See also `th waifu2x.lua -h`.
### Video Encoding
\* `avconv` is `ffmpeg` on Ubuntu 14.04.
\* `avconv` is an alias of `ffmpeg` on Ubuntu 14.04.
Extracting images and audio from a video. (range: 00:09:00 ~ 00:12:00)
```
@ -144,6 +151,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 -
```
## Training Your Own Model
Note: If you have the cuDNN library, you can use the cudnn kernel with the `-backend cudnn` option, and you can convert a trained cudnn model to a cunn model with `tools/cudnn2cunn.lua`.
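For example, a conversion run might look like the following sketch; the model path is a placeholder, and the options follow the `cudnn2cunn.lua` CLI shown later in this diff.
```
th tools/cudnn2cunn.lua -model models/my_model.t7 -iformat ascii -oformat ascii
```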
### Data Preparation
@ -151,7 +159,7 @@ Generating a file list.
```
find /path/to/image/dir -name "*.png" > data/image_list.txt
```
(You should use PNG! In my case, waifu2x is trained with 3000 high-resolution-noise-free-PNG images.)
You should use noise-free images. In my case, waifu2x is trained with 6000 high-resolution, noise-free PNG images.
Converting training data.
```


@ -3,7 +3,15 @@ require 'pl'
CACHE_DIR="cache"
TTL = 3600 * 24
local files = dir.getfiles(CACHE_DIR, "*.png")
local files = {}
local image_cache = dir.getfiles(CACHE_DIR, "*.png")
local url_cache = dir.getfiles(CACHE_DIR, "url_*")
for i = 1, #image_cache do
table.insert(files, image_cache[i])
end
for i = 1, #url_cache do
table.insert(files, url_cache[i])
end
local now = os.time()
for i, f in pairs(files) do
if now - path.getmtime(f) > TTL then


@ -2,51 +2,17 @@
<html>
<head>
<meta charset="UTF-8">
<link rel="canonical" href="http://waifu2x.udp.jp/">
<title>waifu2x</title>
<style type="text/css">
body {
margin: 1em 2em 1em 2em;
background: LightGray;
width: 640px;
}
fieldset {
margin-top: 1em;
margin-bottom: 1em;
}
.about {
position: relative;
display: inline-block;
font-size: 0.9em;
padding: 1em 5px 0.2em 0;
}
.help {
font-size: 0.85em;
margin: 1em 0 0 0;
}
</style>
<link href="style.css" rel="stylesheet" type="text/css">
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
<script type="text/javascript">
function clear_file() {
var new_file = $("#file").clone();
new_file.change(clear_url);
$("#file").replaceWith(new_file);
}
function clear_url() {
$("#url").val("")
}
$(function (){
$("#url").change(clear_file);
$("#file").change(clear_url);
})
</script>
<script type="text/javascript" src="ui.js"></script>
</head>
<body>
<h1>waifu2x</h1>
<div class="header">
<div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
<img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
<a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
<div class="github-banner">
<img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
<a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
</div>
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
</div>
@ -66,12 +32,14 @@
Limits: Size: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px
</div>
</fieldset>
<fieldset>
<fieldset class="noise-field">
<legend>Noise Reduction (expects JPEG artifacts)</legend>
<label><input type="radio" name="noise" value="0"> None</label>
<label><input type="radio" name="noise" value="1" checked="checked"> Medium</label>
<label><input type="radio" name="noise" value="2"> High</label>
<div class="help">When using 2x scaling, we never recommend to use high level of noise reduction, it almost always makes image worse, it makes sense for only some rare cases when image had really bad quality from the beginning.</div>
<div class="help">
When using 2x scaling, we never recommend to use high level of noise reduction, it almost always makes image worse, it makes sense for only some rare cases when image had really bad quality from the beginning.
</div>
</fieldset>
<fieldset>
<legend>Upscaling</legend>
@ -82,7 +50,7 @@
<input type="submit"/>
</form>
<div class="help">
<ul style="padding-left: 15px;">
<ul class="padding-left">
<li>If you are using Firefox, please press Ctrl+S to save the image; the "Save Image" option doesn't work.
</ul>
</div>


@ -2,51 +2,17 @@
<html lang="ja">
<head>
<meta charset="UTF-8">
<link rel="canonical" href="http://waifu2x.udp.jp/">
<link href="style.css" rel="stylesheet" type="text/css">
<title>waifu2x</title>
<style type="text/css">
body {
margin: 1em 2em 1em 2em;
background: LightGray;
width: 640px;
}
fieldset {
margin-top: 1em;
margin-bottom: 1em;
}
.about {
position: relative;
display: inline-block;
font-size: 0.8em;
padding: 1em 5px 0.2em 0;
}
.help {
font-size: 0.8em;
margin: 1em 0 0 0;
}
</style>
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
<script type="text/javascript">
function clear_file() {
var new_file = $("#file").clone();
new_file.change(clear_url);
$("#file").replaceWith(new_file);
}
function clear_url() {
$("#url").val("")
}
$(function (){
$("#url").change(clear_file);
$("#file").change(clear_url);
})
</script>
<script type="text/javascript" src="ui.js"></script>
</head>
<body>
<h1>waifu2x</h1>
<div class="header">
<div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
<img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
<a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
<div class="github-banner">
<img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
<a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
</div>
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
</div>
@ -66,7 +32,7 @@
制限: サイズ: 2MB, ノイズ除去: 2560x2560px, 拡大: 1280x1280px
</div>
</fieldset>
<fieldset>
<fieldset class="noise-field">
<legend>ノイズ除去 (JPEGノイズを想定)</legend>
<label><input type="radio" name="noise" value="0"> なし</label>
<label><input type="radio" name="noise" value="1" checked="checked"></label>
@ -81,7 +47,7 @@
<input type="submit" value="実行"/>
</form>
<div class="help">
<ul style="padding-left: 15px;">
<ul class="padding-left">
<li>なし/なしで入力画像を変換せずに出力する。ブラウザのタブで変換結果を比較したい人用。
<li>Firefoxの方は、右クリから画像が保存できないようなので、CTRL+SキーかALTキー後 ファイル - ページを保存 で画像を保存してください。
</ul>


@ -2,51 +2,18 @@
<html>
<head>
<meta charset="UTF-8">
<link rel="canonical" href="http://waifu2x.udp.jp/">
<link href="style.css" rel="stylesheet" type="text/css">
<title>waifu2x</title>
<style type="text/css">
body {
margin: 1em 2em 1em 2em;
background: LightGray;
width: 640px;
}
fieldset {
margin-top: 1em;
margin-bottom: 1em;
}
.about {
position: relative;
display: inline-block;
font-size: 0.9em;
padding: 1em 5px 0.2em 0;
}
.help {
font-size: 0.85em;
margin: 1em 0 0 0;
}
</style>
<link href="style.css" rel="stylesheet" type="text/css">
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
<script type="text/javascript">
function clear_file() {
var new_file = $("#file").clone();
new_file.change(clear_url);
$("#file").replaceWith(new_file);
}
function clear_url() {
$("#url").val("")
}
$(function (){
$("#url").change(clear_file);
$("#file").change(clear_url);
})
</script>
<script type="text/javascript" src="ui.js"></script>
</head>
<body>
<h1>waifu2x</h1>
<div class="header">
<div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
<img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
<a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
<div class="github-banner">
<img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
<a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
</div>
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
</div>
@ -66,11 +33,11 @@
Макс. размер файла — 2MB, устранение шума — макс. 2560x2560px, апскейл — 1280x1280px
</div>
</fieldset>
<fieldset>
<legend>Устранение шума (артефактов JPEG)</legend>
<fieldset class="noise-field">
<legend>Устранение шума (артефактов JPEG)</legend>
<label><input type="radio" name="noise" value="0"> Нет</label>
<label><input type="radio" name="noise" value="1" checked="checked"> Средне</label>
<label><input type="radio" name="noise" value="2"> Сильно (не рекомендуется)</label>
<label><input type="radio" name="noise" value="2"> Сильно</label>
<div class="help">Устранение шума нужно использовать, если на картинке действительно есть шум, иначе это даст противоположный эффект. Также не рекомендуется сильное устранение шума, оно даёт выгоду только в редких случаях, когда картинка изначально была сильно испорчена.</div>
</fieldset>
<fieldset>
@ -82,8 +49,9 @@
<input type="submit"/>
</form>
<div class="help">
<ul style="padding-left: 15px;">
<ul class="padding-left">
<li>Если Вы используете Firefox, для сохранения изображения Вам придётся нажать Ctrl+S (опция в меню "Сохранить изображение" работать не будет!)
</li>
</ul>
</div>
</body>

assets/style.css (new file)

@ -0,0 +1,52 @@
body {
margin: 1em 2em 1em 2em;
background: LightGray;
width: 640px;
}
fieldset {
margin-top: 1em;
margin-bottom: 1em;
}
.about {
position: relative;
display: inline-block;
font-size: 0.9em;
padding: 1em 5px 0.2em 0;
}
.help {
font-size: 0.8em;
margin: 1em 0 0 0;
}
.github-banner {
position:absolute;
display:block;
top:0;
left:540px;
max-height:140px;
}
.github-banner-image {
position: absolute;
display: block;
left: 0;
top: 0;
width: 149px;
height: 149px;
border: 0;
}
.github-banner-link {
position: absolute;
display: block;
left:0;
top:0;
width:149px;
height:130px;
}
.padding-left {
padding-left: 15px;
}
.hide {
display: none;
}
.experimental {
margin-bottom: 1em;
}

assets/ui.js (new file)

@ -0,0 +1,80 @@
$(function (){
function clear_file() {
var new_file = $("#file").clone();
new_file.change(clear_url);
$("#file").replaceWith(new_file);
}
function clear_url() {
$("#url").val("")
}
function on_change_style(e) {
$("input[name=style]").parents("label").each(
function (i, elm) {
$(elm).css("font-weight", "normal");
});
var checked = $("input[name=style]:checked");
checked.parents("label").css("font-weight", "bold");
if (checked.val() == "art") {
$("h1").text("waifu2x");
} else {
$("h1").html("w<s>/a/</s>ifu2x");
}
}
function on_change_noise_level(e)
{
$("input[name=noise]").parents("label").each(
function (i, elm) {
$(elm).css("font-weight", "normal");
});
var checked = $("input[name=noise]:checked");
if (checked.val() != 0) {
checked.parents("label").css("font-weight", "bold");
}
}
function on_change_scale_factor(e)
{
$("input[name=scale]").parents("label").each(
function (i, elm) {
$(elm).css("font-weight", "normal");
});
var checked = $("input[name=scale]:checked");
if (checked.val() != 0) {
checked.parents("label").css("font-weight", "bold");
}
}
function on_change_white_noise(e)
{
$("input[name=white_noise]").parents("label").each(
function (i, elm) {
$(elm).css("font-weight", "normal");
});
var checked = $("input[name=white_noise]:checked");
if (checked.val() != 0) {
checked.parents("label").css("font-weight", "bold");
}
}
function on_click_experimental_button(e)
{
if ($(this).hasClass("close")) {
$(".experimental .container").show();
$(this).removeClass("close");
} else {
$(".experimental .container").hide();
$(this).addClass("close");
}
e.preventDefault();
e.stopPropagation();
}
$("#url").change(clear_file);
$("#file").change(clear_url);
//$("input[name=style]").change(on_change_style);
$("input[name=noise]").change(on_change_noise_level);
$("input[name=scale]").change(on_change_scale_factor);
//$("input[name=white_noise]").change(on_change_white_noise);
//$(".experimental .button").click(on_click_experimental_button)
//on_change_style();
on_change_scale_factor();
on_change_noise_level();
})


@ -1,48 +1,47 @@
require './lib/portable'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
require 'pl'
require 'image'
local settings = require './lib/settings'
local image_loader = require './lib/image_loader'
local function count_lines(file)
local fp = io.open(file, "r")
local count = 0
for line in fp:lines() do
count = count + 1
end
fp:close()
return count
end
local function crop_4x(x)
local w = x:size(3) % 4
local h = x:size(2) % 4
return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
end
local compression = require 'compression'
local settings = require 'settings'
local image_loader = require 'image_loader'
local iproc = require 'iproc'
local function load_images(list)
local count = count_lines(list)
local fp = io.open(list, "r")
local MARGIN = 32
local lines = utils.split(file.read(list), "\n")
local x = {}
local c = 0
for line in fp:lines() do
local im = crop_4x(image_loader.load_byte(line))
if im then
if im:size(2) >= settings.crop_size * 2 and im:size(3) >= settings.crop_size * 2 then
table.insert(x, im)
end
for i = 1, #lines do
local line = lines[i]
local im, alpha = image_loader.load_byte(line)
if alpha then
io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
else
print("error:" .. line)
im = iproc.crop_mod4(im)
local scale = 1.0
if settings.random_half then
scale = 2.0
end
if im then
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
table.insert(x, compression.compress(im))
else
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
end
else
io.stderr:write(string.format("\n%s: skip: load error.\n", line))
end
end
c = c + 1
xlua.progress(c, count)
if c % 10 == 0 then
xlua.progress(i, #lines)
if i % 10 == 0 then
collectgarbage()
end
end
return x
end
torch.manualSeed(settings.seed)
print(settings)
local x = load_images(settings.image_list)
torch.save(settings.images, x)


@ -1,34 +0,0 @@
require 'cunn'
require 'cudnn'
require 'cutorch'
require './lib/LeakyReLU'
local srcnn = require 'lib/srcnn'
local function cudnn2cunn(cudnn_model)
local cunn_model = srcnn.waifu2x("y")
local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")
for i = 1, #from_seq do
local from = from_seq[i]
local to = to_seq[i]
to.weight:copy(from.weight)
to.bias:copy(from.bias)
end
cunn_model:cuda()
cunn_model:evaluate()
return cunn_model
end
local cmd = torch.CmdLine()
cmd:text()
cmd:text("convert cudnn model to cunn model ")
cmd:text("Options:")
cmd:option("-model", "./model.t7", 'path of cudnn model file')
cmd:option("-iformat", "ascii", 'input format')
cmd:option("-oformat", "ascii", 'output format')
local opt = cmd:parse(arg)
local cudnn_model = torch.load(opt.model, opt.iformat)
local cunn_model = cudnn2cunn(cudnn_model)
torch.save(opt.model, cunn_model, opt.oformat)


@ -1,23 +0,0 @@
-- adapted from https://github.com/marcan/cl-waifu2x
require './lib/portable'
require './lib/LeakyReLU'
local cjson = require "cjson"
local model = torch.load(arg[1], "ascii")
local jmodules = {}
local modules = model:findModules("nn.SpatialConvolutionMM")
for i = 1, #modules, 1 do
local module = modules[i]
local jmod = {
kW = module.kW,
kH = module.kH,
nInputPlane = module.nInputPlane,
nOutputPlane = module.nOutputPlane,
bias = torch.totable(module.bias:float()),
weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
}
table.insert(jmodules, jmod)
end
io.write(cjson.encode(jmodules))


@ -1,8 +1,7 @@
#!/bin/sh
th waifu2x.lua -noise_level 1 -m noise_scale -i images/miku_small.png -o images/miku_small_waifu2x.png
th waifu2x.lua -m scale -i images/miku_small.png -o images/miku_small_waifu2x.png
th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_small_noisy.png -o images/miku_small_noisy_waifu2x.png
th waifu2x.lua -noise_level 2 -m noise -i images/miku_noisy.png -o images/miku_noisy_waifu2x.png
th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_CC_BY-NC_noisy.jpg -o images/miku_CC_BY-NC_noisy_waifu2x.png
th waifu2x.lua -noise_level 2 -m noise -i images/lena.png -o images/lena_waifu2x.png
th waifu2x.lua -m scale -model_dir models/ukbench -i images/lena.png -o images/lena_waifu2x_ukbench.png

13 binary image files changed (contents not shown).


@ -0,0 +1,39 @@
-- ref: https://en.wikipedia.org/wiki/Huber_loss
local ClippedWeightedHuberCriterion, parent = torch.class('w2nn.ClippedWeightedHuberCriterion','nn.Criterion')
function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
parent.__init(self)
self.clip = clip
self.gamma = gamma or 1.0
self.weight = w:clone()
self.diff = torch.Tensor()
self.diff_abs = torch.Tensor()
--self.outlier_rate = 0.0
self.square_loss_buff = torch.Tensor()
self.linear_loss_buff = torch.Tensor()
end
function ClippedWeightedHuberCriterion:updateOutput(input, target)
self.diff:resizeAs(input):copy(input)
self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
for i = 1, input:size(1) do
self.diff[i]:add(-1, target[i]):cmul(self.weight)
end
self.diff_abs:resizeAs(self.diff):copy(self.diff):abs()
local square_targets = self.diff[torch.lt(self.diff_abs, self.gamma)]
local linear_targets = self.diff[torch.ge(self.diff_abs, self.gamma)]
local square_loss = self.square_loss_buff:resizeAs(square_targets):copy(square_targets):pow(2.0):mul(0.5):sum()
local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():add(-0.5 * self.gamma):mul(self.gamma):sum()
--self.outlier_rate = linear_targets:nElement() / input:nElement()
self.output = (square_loss + linear_loss) / input:nElement()
return self.output
end
function ClippedWeightedHuberCriterion:updateGradInput(input, target)
local norm = 1.0 / input:nElement()
self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
local outlier = torch.ge(self.diff_abs, self.gamma)
self.gradInput[outlier] = torch.sign(self.diff[outlier]) * self.gamma * norm
return self.gradInput
end
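For reference, the per-element loss that `updateOutput` above computes on the clipped, weighted difference $d$ with threshold $\gamma$ is the standard Huber form (this restates the code, not the upstream docs):

$$
L(d) = \begin{cases}
\frac{1}{2}d^2 & \text{if } |d| < \gamma \\
\gamma\left(|d| - \frac{\gamma}{2}\right) & \text{if } |d| \ge \gamma
\end{cases}
$$

The criterion then averages $L(d)$ over all elements of the input.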

lib/DepthExpand2x.lua (new file)

@ -0,0 +1,77 @@
if w2nn.DepthExpand2x then
return w2nn.DepthExpand2x
end
local DepthExpand2x, parent = torch.class('w2nn.DepthExpand2x','nn.Module')
function DepthExpand2x:__init()
parent.__init(self)
end
function DepthExpand2x:updateOutput(input)
local x = input
-- (batch_size, depth, height, width)
self.shape = x:size()
assert(self.shape:size() == 4, "input must be 4d tensor")
assert(self.shape[2] % 4 == 0, "depth must be a multiple of 4")
-- (batch_size, width, height, depth)
x = x:transpose(2, 4)
-- (batch_size, width, height * 2, depth / 2)
x = x:reshape(self.shape[1], self.shape[4], self.shape[3] * 2, self.shape[2] / 2)
-- (batch_size, height * 2, width, depth / 2)
x = x:transpose(2, 3)
-- (batch_size, height * 2, width * 2, depth / 4)
x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4] * 2, self.shape[2] / 4)
-- (batch_size, depth / 4, height * 2, width * 2)
x = x:transpose(2, 4)
x = x:transpose(3, 4)
self.output:resizeAs(x):copy(x) -- contiguous
return self.output
end
function DepthExpand2x:updateGradInput(input, gradOutput)
-- (batch_size, depth / 4, height * 2, width * 2)
local x = gradOutput
-- (batch_size, height * 2, width * 2, depth / 4)
x = x:transpose(2, 4)
x = x:transpose(2, 3)
-- (batch_size, height * 2, width, depth / 2)
x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4], self.shape[2] / 2)
-- (batch_size, width, height * 2, depth / 2)
x = x:transpose(2, 3)
-- (batch_size, width, height, depth)
x = x:reshape(self.shape[1], self.shape[4], self.shape[3], self.shape[2])
-- (batch_size, depth, height, width)
x = x:transpose(2, 4)
self.gradInput:resizeAs(x):copy(x)
return self.gradInput
end
function DepthExpand2x.test()
require 'image'
local function show(x)
local img = torch.Tensor(3, x:size(3), x:size(4))
img[1]:copy(x[1][1])
img[2]:copy(x[1][2])
img[3]:copy(x[1][3])
image.display(img)
end
local img = image.lena()
local x = torch.Tensor(1, img:size(1) * 4, img:size(2), img:size(3))
for i = 0, img:size(1) * 4 - 1 do
src_index = ((i % 3) + 1)
x[1][i + 1]:copy(img[src_index])
end
show(x)
local de2x = w2nn.DepthExpand2x()
out = de2x:forward(x)
show(out)
out = de2x:updateGradInput(x, out)
show(out)
end
return DepthExpand2x
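A minimal shape check for the depth-to-space rearrangement above; it assumes the `w2nn` namespace table is initialized and this module has been loaded:

```lua
-- (batch, depth, height, width) with depth % 4 == 0
local de2x = w2nn.DepthExpand2x()
local x = torch.Tensor(1, 16, 8, 8):uniform()
local y = de2x:forward(x)
print(y:size()) -- 1x4x16x16: depth/4 channels at twice the spatial resolution
```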


@ -1,7 +1,8 @@
if nn.LeakyReLU then
return
if w2nn and w2nn.LeakyReLU then
return w2nn.LeakyReLU
end
local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
local LeakyReLU, parent = torch.class('w2nn.LeakyReLU','nn.Module')
function LeakyReLU:__init(negative_scale)
parent.__init(self)


@ -0,0 +1,31 @@
if nn.LeakyReLU then
return nn.LeakyReLU
end
local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
function LeakyReLU:__init(negative_scale)
parent.__init(self)
self.negative_scale = negative_scale or 0.333
self.negative = torch.Tensor()
end
function LeakyReLU:updateOutput(input)
self.output:resizeAs(input):copy(input):abs():add(input):div(2)
self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5*self.negative_scale)
self.output:add(self.negative)
return self.output
end
function LeakyReLU:updateGradInput(input, gradOutput)
self.gradInput:resizeAs(gradOutput)
-- filter positive
self.negative:sign():add(1)
torch.cmul(self.gradInput, gradOutput, self.negative)
-- filter negative
self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput)
self.gradInput:add(self.negative)
return self.gradInput
end
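A quick numeric check of the forward pass above, once this shim is loaded (default `negative_scale` is 0.333):

```lua
local act = nn.LeakyReLU()
local y = act:forward(torch.Tensor({-3, 0, 3}))
-- y = {-0.999, 0, 3}: positive inputs pass through, negatives are scaled by 0.333
```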


@ -0,0 +1,25 @@
local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion')
function WeightedMSECriterion:__init(w)
parent.__init(self)
self.weight = w:clone()
self.diff = torch.Tensor()
self.loss = torch.Tensor()
end
function WeightedMSECriterion:updateOutput(input, target)
self.diff:resizeAs(input):copy(input)
for i = 1, input:size(1) do
self.diff[i]:add(-1, target[i]):cmul(self.weight)
end
self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
self.output = self.loss:mean()
return self.output
end
function WeightedMSECriterion:updateGradInput(input, target)
local norm = 2.0 / input:nElement()
self.gradInput:resizeAs(input):copy(self.diff):mul(norm)
return self.gradInput
end
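A usage sketch, assuming the `w2nn` namespace is initialized; the shapes and uniform weights are placeholders (training code would pass per-channel weights):

```lua
local w = torch.Tensor(3 * 8 * 8):fill(1)           -- per-element weights
local crit = w2nn.WeightedMSECriterion(w)
local input  = torch.Tensor(4, 3 * 8 * 8):uniform() -- (batch, flattened pixels)
local target = torch.Tensor(4, 3 * 8 * 8):uniform()
local loss = crit:forward(input, target)            -- mean weighted squared difference
local grad = crit:backward(input, target)           -- (2/N) * weighted difference
```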


@ -1,9 +1,5 @@
require './lib/portable'
require './lib/LeakyReLU'
torch.setdefaulttensortype("torch.FloatTensor")
-- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
local function zeroDataSize(data)
if type(data) == 'table' then
for i = 1, #data do
@ -14,7 +10,6 @@ local function zeroDataSize(data)
end
return data
end
-- Resize the output, gradInput, etc temporary tensors to zero (so that the
-- on disk size is smaller)
local function cleanupModel(node)
@ -27,7 +22,7 @@ local function cleanupModel(node)
if node.finput ~= nil then
node.finput = zeroDataSize(node.finput)
end
if tostring(node) == "nn.LeakyReLU" then
if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
if node.negative ~= nil then
node.negative = zeroDataSize(node.negative)
end
@ -46,23 +41,8 @@ local function cleanupModel(node)
end
end
end
collectgarbage()
end
local cmd = torch.CmdLine()
cmd:text()
cmd:text("cleanup model")
cmd:text("Options:")
cmd:option("-model", "./model.t7", 'path of model file')
cmd:option("-iformat", "binary", 'input format')
cmd:option("-oformat", "binary", 'output format')
local opt = cmd:parse(arg)
local model = torch.load(opt.model, opt.iformat)
if model then
function w2nn.cleanup_model(model)
cleanupModel(model)
torch.save(opt.model, model, opt.oformat)
else
error("model not found")
return model
end

lib/compression.lua (new file)

@ -0,0 +1,17 @@
-- snappy compression for ByteTensor
require 'snappy'
local compression = {}
compression.compress = function (bt)
local enc = snappy.compress(bt:storage():string())
return {bt:size(), torch.ByteStorage():string(enc)}
end
compression.decompress = function(data)
local size = data[1]
local dec = snappy.decompress(data[2]:string())
local bt = torch.ByteTensor(unpack(torch.totable(size)))
bt:storage():string(dec)
return bt
end
return compression
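A round-trip sketch, assuming lua-csnappy is installed and the repo's `lib/` directory is on `package.path`:

```lua
local compression = require 'compression'
local bt = torch.ByteTensor(3, 64, 64):random(0, 255)
local packed = compression.compress(bt)        -- {size, snappy-compressed storage}
local restored = compression.decompress(packed)
assert((restored:float() - bt:float()):abs():sum() == 0) -- lossless round trip
```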

lib/data_augmentation.lua (new file)

@ -0,0 +1,104 @@
require 'image'
local iproc = require 'iproc'
local data_augmentation = {}
local function pcacov(x)
local mean = torch.mean(x, 1)
local xm = x - torch.ger(torch.ones(x:size(1)), mean:squeeze())
local c = torch.mm(xm:t(), xm)
c:div(x:size(1) - 1)
local ce, cv = torch.symeig(c, 'V')
return ce, cv
end
function data_augmentation.color_noise(src, p, factor)
factor = factor or 0.1
if torch.uniform() < p then
local src, conversion = iproc.byte2float(src)
local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
local ce, cv = pcacov(src_t)
local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor)
pca_space = torch.mm(src_t, cv):t():contiguous()
for i = 1, 3 do
pca_space[i]:mul(color_scale[i])
end
local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
dest[torch.lt(dest, 0.0)] = 0.0
dest[torch.gt(dest, 1.0)] = 1.0
if conversion then
dest = iproc.float2byte(dest)
end
return dest
else
return src
end
end
function data_augmentation.overlay(src, p)
if torch.uniform() < p then
local r = torch.uniform()
local src, conversion = iproc.byte2float(src)
src = src:contiguous()
local flip = data_augmentation.flip(src)
flip:mul(r):add(src * (1.0 - r))
if conversion then
flip = iproc.float2byte(flip)
end
return flip
else
return src
end
end
function data_augmentation.shift_1px(src)
-- reducing the even/odd issue in nearest neighbor scaler.
local direction = torch.random(1, 4)
local x_shift = 0
local y_shift = 0
if direction == 1 then
x_shift = 1
y_shift = 0
elseif direction == 2 then
x_shift = 0
y_shift = 1
elseif direction == 3 then
x_shift = 1
y_shift = 1
elseif direction == 4 then
x_shift = 0
y_shift = 0
end
local w = src:size(3) - x_shift
local h = src:size(2) - y_shift
w = w - (w % 4)
h = h - (h % 4)
local dest = iproc.crop(src, x_shift, y_shift, x_shift + w, y_shift + h)
return dest
end
function data_augmentation.flip(src)
local flip = torch.random(1, 4)
local tr = torch.random(1, 2)
local src, conversion = iproc.byte2float(src)
local dest
src = src:contiguous()
if tr == 1 then
-- pass
elseif tr == 2 then
src = src:transpose(2, 3):contiguous()
end
if flip == 1 then
dest = image.hflip(src)
elseif flip == 2 then
dest = image.vflip(src)
elseif flip == 3 then
dest = image.hflip(image.vflip(src))
elseif flip == 4 then
dest = src
end
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
return data_augmentation
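A sketch chaining the augmentations above on one ByteTensor image; `image.lena()` stands in for a training sample, and `lib/` is assumed to be on `package.path`:

```lua
require 'image'
local data_augmentation = require 'data_augmentation'
local src = image.lena():mul(255):byte()            -- 3xHxW, values in [0, 255]
src = data_augmentation.flip(src)                   -- random flip/transpose
src = data_augmentation.color_noise(src, 0.5, 0.1)  -- PCA color noise with p = 0.5
src = data_augmentation.overlay(src, 0.3)           -- blend with a flipped copy, p = 0.3
src = data_augmentation.shift_1px(src)              -- random 1px shift, crop to size % 4
```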


@ -1,74 +1,118 @@
local gm = require 'graphicsmagick'
local ffi = require 'ffi'
local iproc = require 'iproc'
require 'pl'
local image_loader = {}
function image_loader.decode_float(blob)
local im, alpha = image_loader.decode_byte(blob)
if im then
im = im:float():div(255)
end
return im, alpha
end
function image_loader.encode_png(rgb, alpha)
if rgb:type() == "torch.ByteTensor" then
error("expect FloatTensor")
end
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
local background_color = 0.5
function image_loader.encode_png(rgb, alpha, depth)
depth = depth or 8
rgb = iproc.byte2float(rgb)
if alpha then
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
end
local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
rgba[1]:copy(rgb[1])
rgba[2]:copy(rgb[2])
rgba[3]:copy(rgb[3])
rgba[4]:copy(alpha)
if depth < 16 then
rgba:add(clip_eps8)
rgba[torch.lt(rgba, 0.0)] = 0.0
rgba[torch.gt(rgba, 1.0)] = 1.0
else
rgba:add(clip_eps16)
rgba[torch.lt(rgba, 0.0)] = 0.0
rgba[torch.gt(rgba, 1.0)] = 1.0
end
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
im:format("png")
return im:toBlob(9)
return im:depth(depth):format("PNG"):toString(9)
else
if depth < 16 then
rgb = rgb:clone():add(clip_eps8)
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
else
rgb = rgb:clone():add(clip_eps16)
rgb[torch.lt(rgb, 0.0)] = 0.0
rgb[torch.gt(rgb, 1.0)] = 1.0
end
local im = gm.Image(rgb, "RGB", "DHW")
im:format("png")
return im:toBlob(9)
return im:depth(depth):format("PNG"):toString(9)
end
end
function image_loader.save_png(filename, rgb, alpha)
local blob, len = image_loader.encode_png(rgb, alpha)
function image_loader.save_png(filename, rgb, alpha, depth)
depth = depth or 8
local blob = image_loader.encode_png(rgb, alpha, depth)
local fp = io.open(filename, "wb")
fp:write(ffi.string(blob, len))
if not fp then
error("IO error: " .. filename)
end
fp:write(blob)
fp:close()
return true
end
function image_loader.decode_byte(blob)
function image_loader.decode_float(blob)
local load_image = function()
local im = gm.Image()
local alpha = nil
local gamma_lcd = 0.454545
im:fromBlob(blob, #blob)
if im:colorspace() == "CMYK" then
im:colorspace("RGB")
end
local gamma = math.floor(im:gamma() * 1000000) / 1000000
if gamma ~= 0 and gamma ~= gamma_lcd then
im:gammaCorrection(gamma / gamma_lcd)
end
-- FIXME: How to detect that an image has an alpha channel?
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
-- split alpha channel
im = im:toTensor('float', 'RGBA', 'DHW')
local sum_alpha = (im[4] - 1):sum()
if sum_alpha > 0 or sum_alpha < 0 then
local sum_alpha = (im[4] - 1.0):sum()
if sum_alpha < 0 then
alpha = im[4]:reshape(1, im:size(2), im:size(3))
-- drop full transparent background
local mask = torch.le(alpha, 0.0)
im[1][mask] = background_color
im[2][mask] = background_color
im[3][mask] = background_color
end
local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
new_im[1]:copy(im[1])
new_im[2]:copy(im[2])
new_im[3]:copy(im[3])
im = new_im:mul(255):byte()
im = new_im
else
im = im:toTensor('byte', 'RGB', 'DHW')
im = im:toTensor('float', 'RGB', 'DHW')
end
return {im, alpha}
return {im, alpha, blob}
end
local state, ret = pcall(load_image)
if state then
return ret[1], ret[2]
return ret[1], ret[2], ret[3]
else
return nil
return nil, nil, nil
end
end
function image_loader.decode_byte(blob)
local im, alpha
im, alpha, blob = image_loader.decode_float(blob)
if im then
im = iproc.float2byte(im)
-- hmm, alpha does not convert here
return im, alpha, blob
else
return nil, nil, nil
end
end
function image_loader.load_float(file)
@ -90,18 +134,16 @@ function image_loader.load_byte(file)
return image_loader.decode_byte(buff)
end
local function test()
require 'image'
local img
img = image_loader.load_float("./a.jpg")
if img then
print(img:min())
print(img:max())
image.display(img)
end
img = image_loader.load_float("./b.png")
if img then
image.display(img)
end
torch.setdefaulttensortype("torch.FloatTensor")
local a = image_loader.load_float("../images/lena.png")
local blob = image_loader.encode_png(a)
local b = image_loader.decode_float(blob)
assert((b - a):abs():sum() == 0)
a = image_loader.load_byte("../images/lena.png")
blob = image_loader.encode_png(a)
b = image_loader.decode_byte(blob)
assert((b:float() - a:float()):abs():sum() == 0)
end
--test()
return image_loader
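A load/save sketch; the file names are placeholders, and `lib/` is assumed to be on `package.path`:

```lua
local image_loader = require 'image_loader'
local rgb, alpha = image_loader.load_float("input.png")
if rgb then
   -- re-encode with the alpha channel preserved, at 16-bit depth
   image_loader.save_png("output.png", rgb, alpha, 16)
else
   error("failed to load input.png")
end
```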


@ -1,16 +1,78 @@
local gm = require 'graphicsmagick'
local image = require 'image'
local iproc = {}
function iproc.scale(src, width, height, filter)
local t = "float"
if src:type() == "torch.ByteTensor" then
t = "byte"
local iproc = {}
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
function iproc.crop_mod4(src)
local w = src:size(3) % 4
local h = src:size(2) % 4
return iproc.crop(src, 0, 0, src:size(3) - w, src:size(2) - h)
end
function iproc.crop(src, w1, h1, w2, h2)
local dest
if src:dim() == 3 then
dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}]:clone()
else -- dim == 2
dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}]:clone()
end
return dest
end
function iproc.crop_nocopy(src, w1, h1, w2, h2)
local dest
if src:dim() == 3 then
dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}]
else -- dim == 2
dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}]
end
return dest
end
function iproc.byte2float(src)
local conversion = false
local dest = src
if src:type() == "torch.ByteTensor" then
conversion = true
dest = src:float():div(255.0)
end
return dest, conversion
end
function iproc.float2byte(src)
local conversion = false
local dest = src
if src:type() == "torch.FloatTensor" then
conversion = true
dest = (src + clip_eps8):mul(255.0)
dest[torch.lt(dest, 0.0)] = 0
dest[torch.gt(dest, 255.0)] = 255.0
dest = dest:byte()
end
return dest, conversion
end
function iproc.scale(src, width, height, filter)
local conversion
src, conversion = iproc.byte2float(src)
filter = filter or "Box"
local im = gm.Image(src, "RGB", "DHW")
im:size(math.ceil(width), math.ceil(height), filter)
return im:toTensor(t, "RGB", "DHW")
local dest = im:toTensor("float", "RGB", "DHW")
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
function iproc.scale_with_gamma22(src, width, height, filter)
local conversion
src, conversion = iproc.byte2float(src)
filter = filter or "Box"
local im = gm.Image(src, "RGB", "DHW")
im:gammaCorrection(1.0 / 2.2):
size(math.ceil(width), math.ceil(height), filter):
gammaCorrection(2.2)
local dest = im:toTensor("float", "RGB", "DHW")
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
function iproc.padding(img, w1, w2, h1, h2)
local dst_height = img:size(2) + h1 + h2
@ -22,5 +84,51 @@ function iproc.padding(img, w1, w2, h1, h2)
flow[2]:add(-w1)
return image.warp(img, flow, "simple", false, "clamp")
end
function iproc.white_noise(src, std, rgb_weights, gamma)
gamma = gamma or 0.454545
local conversion
src, conversion = iproc.byte2float(src)
std = std or 0.01
local noise = torch.Tensor():resizeAs(src):normal(0, std)
if rgb_weights then
noise[1]:mul(rgb_weights[1])
noise[2]:mul(rgb_weights[2])
noise[3]:mul(rgb_weights[3])
end
local dest
if gamma ~= 0 then
dest = src:clone():pow(gamma):add(noise)
dest[torch.lt(dest, 0.0)] = 0.0
dest[torch.gt(dest, 1.0)] = 1.0
dest:pow(1.0 / gamma)
else
dest = src + noise
end
if conversion then
dest = iproc.float2byte(dest)
end
return dest
end
local function test_conversion()
local a = torch.linspace(0, 255, 256):float():div(255.0)
local b = iproc.float2byte(a)
local c = iproc.byte2float(a)
local d = torch.linspace(0, 255, 256)
assert((a - c):abs():sum() == 0)
assert((d:float() - b:float()):abs():sum() == 0)
a = torch.FloatTensor({256.0, 255.0, 254.999}):div(255.0)
b = iproc.float2byte(a)
assert(b:float():sum() == 255.0 * 3)
a = torch.FloatTensor({254.0, 254.499, 253.50001}):div(255.0)
b = iproc.float2byte(a)
print(b)
assert(b:float():sum() == 254.0 * 3)
end
--test_conversion()
return iproc
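A sketch of the byte/float round trip above, which the `clip_eps8` offset keeps exact; `lib/` is assumed to be on `package.path`:

```lua
local iproc = require 'iproc'
local byte_img = torch.ByteTensor(3, 32, 32):random(0, 255)
local float_img = iproc.byte2float(byte_img)      -- FloatTensor in [0, 1]
local back = iproc.float2byte(float_img)
assert((back:float() - byte_img:float()):abs():sum() == 0)
local half = iproc.scale(byte_img, 16, 16, "Box") -- ByteTensor in, ByteTensor out
```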


@ -21,20 +21,15 @@ local function minibatch_adam(model, criterion,
input_size[1], input_size[2], input_size[3])
local targets_tmp = torch.Tensor(batch_size,
target_size[1] * target_size[2] * target_size[3])
for t = 1, #train_x, batch_size do
if t + batch_size > #train_x then
break
end
for t = 1, #train_x do
xlua.progress(t, #train_x)
for i = 1, batch_size do
local x, y = transformer(train_x[shuffle[t + i - 1]])
inputs_tmp[i]:copy(x)
targets_tmp[i]:copy(y)
local xy = transformer(train_x[shuffle[t]], false, batch_size)
for i = 1, #xy do
inputs_tmp[i]:copy(xy[i][1])
targets_tmp[i]:copy(xy[i][2])
end
inputs:copy(inputs_tmp)
targets:copy(targets_tmp)
local feval = function(x)
if x ~= parameters then
parameters:copy(x)
@ -50,13 +45,13 @@ local function minibatch_adam(model, criterion,
optim.adam(feval, parameters, config)
c = c + 1
if c % 10 == 0 then
if c % 20 == 0 then
collectgarbage()
end
end
xlua.progress(#train_x, #train_x)
return { mse = sum_loss / count_loss}
return { loss = sum_loss / count_loss}
end
return minibatch_adam


@ -1,69 +1,80 @@
require 'image'
local gm = require 'graphicsmagick'
local iproc = require './iproc'
local reconstruct = require './reconstruct'
local iproc = require 'iproc'
local data_augmentation = require 'data_augmentation'
local pairwise_transform = {}
local function random_half(src, p, min_size)
p = p or 0.5
local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
if p > torch.uniform() then
local function random_half(src, p)
if torch.uniform() < p then
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
else
return src
end
end
local function color_augment(x)
local color_scale = torch.Tensor(3):uniform(0.8, 1.2)
x = x:float():div(255)
for i = 1, 3 do
x[i]:mul(color_scale[i])
end
x[torch.lt(x, 0.0)] = 0.0
x[torch.gt(x, 1.0)] = 1.0
return x:mul(255):byte()
end
local function flip_augment(x, y)
local flip = torch.random(1, 4)
if y then
if flip == 1 then
x = image.hflip(x)
y = image.hflip(y)
elseif flip == 2 then
x = image.vflip(x)
y = image.vflip(y)
elseif flip == 3 then
x = image.hflip(image.vflip(x))
y = image.hflip(image.vflip(y))
elseif flip == 4 then
local function crop_if_large(src, max_size)
local tries = 4
if src:size(2) > max_size and src:size(3) > max_size then
local rect
for i = 1, tries do
local yi = torch.random(0, src:size(2) - max_size)
local xi = torch.random(0, src:size(3) - max_size)
rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
-- ignore simple background
if rect:float():std() > 0 then
break
end
end
return x, y
return rect
else
if flip == 1 then
x = image.hflip(x)
elseif flip == 2 then
x = image.vflip(x)
elseif flip == 3 then
x = image.hflip(image.vflip(x))
elseif flip == 4 then
end
return x
return src
end
end
local INTERPOLATION_PADDING = 16
function pairwise_transform.scale(src, scale, size, offset, options)
options = options or {color_augment = true, random_half = true, rgb = true}
if options.random_half then
src = random_half(src)
local function preprocess(src, crop_size, options)
local dest = src
dest = random_half(dest, options.random_half_rate)
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
dest = data_augmentation.flip(dest)
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
dest = data_augmentation.shift_1px(dest)
return dest
end
local function active_cropping(x, y, size, p, tries)
assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
local r = torch.uniform()
if p < r then
local xi = torch.random(0, y:size(3) - (size + 1))
local yi = torch.random(0, y:size(2) - (size + 1))
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
return xc, yc
else
local best_se = 0.0
local best_xc, best_yc
local m = torch.FloatTensor(x:size(1), size, size)
for i = 1, tries do
local xi = torch.random(0, y:size(3) - (size + 1))
local yi = torch.random(0, y:size(2) - (size + 1))
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
local xcf = iproc.byte2float(xc)
local ycf = iproc.byte2float(yc)
local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
if se >= best_se then
best_xc = xcf
best_yc = ycf
best_se = se
end
end
return best_xc, best_yc
end
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
local down_scale = 1.0 / scale
local y = image.crop(src,
xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING,
xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING)
end
function pairwise_transform.scale(src, scale, size, offset, n, options)
local filters = {
"Box", -- 0.012756949974688
"Box","Box", -- 0.012756949974688
"Blackman", -- 0.013191924552285
--"Cartom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
@ -71,221 +82,173 @@ function pairwise_transform.scale(src, scale, size, offset, options)
"SincFast", -- 0.014095824314306
"Jinc", -- 0.014244299255442
}
local unstable_region_offset = 8
local downscale_filter = filters[torch.random(1, #filters)]
y = flip_augment(y)
if options.color_augment then
y = color_augment(y)
end
local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
x = iproc.scale(x, y:size(3), y:size(2))
y = y:float():div(255)
x = x:float():div(255)
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
end
y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
return x, y
end
function pairwise_transform.jpeg_(src, quality, size, offset, options)
options = options or {color_augment = true, random_half = true, rgb = true}
if options.random_half then
src = random_half(src)
end
local yi = torch.random(0, src:size(2) - size - 1)
local xi = torch.random(0, src:size(3) - size - 1)
local y = src
local x
if options.color_augment then
y = color_augment(y)
end
x = y
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg")
x:samplingFactors({1.0, 1.0, 1.0})
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
y = image.crop(y, xi, yi, xi + size, yi + size)
x = image.crop(x, xi, yi, xi + size, yi + size)
y = y:float():div(255)
x = x:float():div(255)
x, y = flip_augment(x, y)
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
end
return x, image.crop(y, offset, offset, size - offset, size - offset)
end
function pairwise_transform.jpeg(src, level, size, offset, options)
if level == 1 then
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
size, offset,
options)
elseif level == 2 then
local r = torch.uniform()
if r > 0.6 then
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
size, offset,
options)
elseif r > 0.3 then
local quality1 = torch.random(37, 70)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_(src, {quality1, quality2},
size, offset,
options)
else
local quality1 = torch.random(52, 70)
return pairwise_transform.jpeg_(src,
{quality1,
quality1 - torch.random(5, 15),
quality1 - torch.random(15, 25)},
size, offset,
options)
end
else
error("unknown noise level: " .. level)
end
end
function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options)
if options.random_half then
src = random_half(src)
end
local y = preprocess(src, size, options)
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
local down_scale = 1.0 / scale
local filters = {
"Box", -- 0.012756949974688
"Blackman", -- 0.013191924552285
--"Cartom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
--"Hermite", -- 0.013850225205266
"SincFast", -- 0.014095824314306
"Jinc", -- 0.014244299255442
}
local downscale_filter = filters[torch.random(1, #filters)]
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
local y = src
local x
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
y:size(2) * down_scale, downscale_filter),
y:size(3), y:size(2))
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
if options.color_augment then
y = color_augment(y)
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y,
size,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
x = y
x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
return batch
end
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
local unstable_region_offset = 8
local y = preprocess(src, size, options)
local x = y
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg")
x:samplingFactors({1.0, 1.0, 1.0})
x:format("jpeg"):depth(8)
if options.jpeg_sampling_factors == 444 then
x:samplingFactors({1.0, 1.0, 1.0})
else -- 420
x:samplingFactors({2.0, 1.0, 1.0})
end
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
x = iproc.scale(x, y:size(3), y:size(2))
y = image.crop(y,
xi, yi,
xi + size, yi + size)
x = image.crop(x,
xi, yi,
xi + size, yi + size)
x = x:float():div(255)
y = y:float():div(255)
x, y = flip_augment(x, y)
if options.rgb then
else
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
end
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
return x, image.crop(y, offset, offset, size - offset, size - offset)
end
function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options)
options = options or {color_augment = true, random_half = true}
if level == 1 then
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)},
size, offset, options)
elseif level == 2 then
local r = torch.uniform()
if r > 0.6 then
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)},
size, offset, options)
elseif r > 0.3 then
local quality1 = torch.random(37, 70)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2},
size, offset, options)
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y, size,
options.active_cropping_rate,
options.active_cropping_tries)
xc = iproc.byte2float(xc)
yc = iproc.byte2float(yc)
if options.rgb then
else
local quality1 = torch.random(52, 70)
return pairwise_transform.jpeg_scale_(src, scale,
{quality1,
quality1 - torch.random(5, 15),
quality1 - torch.random(15, 25)},
size, offset, options)
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
end
if torch.uniform() < options.nr_rate then
-- reducing noise
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
else
-- retain useful details
table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
end
end
return batch
end
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
if style == "art" then
if level == 1 then
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
size, offset, n, options)
elseif level == 2 then
local r = torch.uniform()
if r > 0.6 then
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
size, offset, n, options)
elseif r > 0.3 then
local quality1 = torch.random(37, 70)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_(src, {quality1, quality2},
size, offset, n, options)
else
local quality1 = torch.random(52, 70)
local quality2 = quality1 - torch.random(5, 15)
local quality3 = quality1 - torch.random(15, 25)
return pairwise_transform.jpeg_(src,
{quality1, quality2, quality3},
size, offset, n, options)
end
else
error("unknown noise level: " .. level)
end
elseif style == "photo" then
if level == 1 then
return pairwise_transform.jpeg_(src, {torch.random(30, 75)},
size, offset, n,
options)
elseif level == 2 then
if torch.uniform() > 0.6 then
return pairwise_transform.jpeg_(src, {torch.random(30, 60)},
size, offset, n, options)
else
local quality1 = torch.random(40, 60)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_(src, {quality1, quality2},
size, offset, n, options)
end
else
error("unknown noise level: " .. level)
end
else
error("unknown noise level: " .. level)
error("unknown style: " .. style)
end
end
local function test_jpeg()
local loader = require './image_loader'
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, false)
image.display({image = y, legend = "y:0"})
image.display({image = x, legend = "x:0"})
for i = 2, 9 do
local y, x = pairwise_transform.jpeg_(pairwise_transform.random_half(src),
{i * 10}, 128, 0, {color_augment = false, random_half = true})
image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
--print(x:mean(), y:mean())
function pairwise_transform.test_jpeg(src)
torch.setdefaulttensortype("torch.FloatTensor")
local options = {random_color_noise_rate = 0.5,
random_half_rate = 0.5,
random_overlay_rate = 0.5,
nr_rate = 1.0,
active_cropping_rate = 0.5,
active_cropping_tries = 10,
max_size = 256,
rgb = true
}
local image = require 'image'
local src = image.lena()
for i = 1, 9 do
local xy = pairwise_transform.jpeg(src,
"art",
torch.random(1, 2),
128, 7, 1, options)
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
end
end
function pairwise_transform.test_scale(src)
torch.setdefaulttensortype("torch.FloatTensor")
local options = {random_color_noise_rate = 0.5,
random_half_rate = 0.5,
random_overlay_rate = 0.5,
active_cropping_rate = 0.5,
active_cropping_tries = 10,
max_size = 256,
rgb = true
}
local image = require 'image'
local src = image.lena()
local function test_scale()
local loader = require './image_loader'
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
for i = 1, 9 do
local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true, rgb = true})
image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
print(y:size(), x:size())
--print(x:mean(), y:mean())
for i = 1, 10 do
local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
end
end
return pairwise_transform
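The level-2 branches above sample a chain of decreasing qualities, so roughly 40% of patches see a single JPEG pass, 30% a double pass, and 30% a triple pass. A minimal sketch of the recompression idea, assuming the `graphicsmagick` binding that `tools/benchmark.lua` below also uses, with the input as a byte RGB tensor in DHW layout:
```lua
-- Sketch only: emulate JPEG generation loss by re-encoding once per quality.
local gm = require 'graphicsmagick'
local function simulate_jpeg_noise(x, qualities)
   for i = 1, #qualities do
      local im = gm.Image(x, "RGB", "DHW")
      im:format("jpeg")
      local blob, len = im:toBlob(qualities[i])
      im:fromBlob(blob, len)
      x = im:toTensor("byte", "RGB", "DHW")
   end
   return x
end
-- e.g. simulate_jpeg_noise(src, {62, 55}) approximates a twice-saved JPEG
```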

View file

@ -1,4 +0,0 @@
require 'torch'
require 'cutorch'
require 'nn'
require 'cunn'

View file

@ -1,5 +1,5 @@
require 'image'
local iproc = require 'iproc'
local function reconstruct_y(model, x, offset, block_size)
if x:dim() == 2 then
@ -48,7 +48,8 @@ local function reconstruct_rgb(model, x, offset, block_size)
end
return new_x
end
local reconstruct = {}
function reconstruct.is_rgb(model)
if model:get(model:size() - 1).weight:size(1) == 3 then
-- 3ch RGB
return true
@ -57,8 +58,23 @@ function model_is_rgb(model)
return false
end
end
function reconstruct.offset_size(model)
local conv = model:findModules("nn.SpatialConvolutionMM")
if #conv > 0 then
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
else
conv = model:findModules("cudnn.SpatialConvolution")
local offset = 0
for i = 1, #conv do
offset = offset + (conv[i].kW - 1) / 2
end
return math.floor(offset)
end
end
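`reconstruct.offset_size` derives the reconstruction border from the model itself: every unpadded k×k convolution trims (k-1)/2 pixels from each side, and the trims accumulate layer by layer. For the seven 3×3 convolutions of the waifu2x architecture in `srcnn.lua` below, that sums to 7, which is why the hard-coded `BLOCK_OFFSET = 7` elsewhere in this diff could be removed. A hypothetical check:
```lua
-- Hypothetical check (not repo code): seven unpadded 3x3 convolutions each
-- trim (3 - 1) / 2 = 1 pixel per side, for a total border offset of 7.
local offset = 0
for _ = 1, 7 do
   offset = offset + (3 - 1) / 2
end
assert(offset == 7)
```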
function reconstruct.image_y(model, x, offset, block_size)
block_size = block_size or 128
local output_size = block_size - offset * 2
@ -78,7 +94,7 @@ function reconstruct.image_y(model, x, offset, block_size)
y[torch.lt(y, 0)] = 0
y[torch.gt(y, 1)] = 1
yuv[1]:copy(y)
local output = image.yuv2rgb(iproc.crop(yuv,
pad_w1, pad_h1,
yuv:size(3) - pad_w2, yuv:size(2) - pad_h2))
output[torch.lt(output, 0)] = 0
@ -110,7 +126,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size)
y[torch.lt(y, 0)] = 0
y[torch.gt(y, 1)] = 1
yuv_jinc[1]:copy(y)
local output = image.yuv2rgb(iproc.crop(yuv_jinc,
pad_w1, pad_h1,
yuv_jinc:size(3) - pad_w2, yuv_jinc:size(2) - pad_h2))
output[torch.lt(output, 0)] = 0
@ -135,7 +151,7 @@ function reconstruct.image_rgb(model, x, offset, block_size)
local pad_w2 = (w - offset) - x:size(3)
local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
local y = reconstruct_rgb(model, input, offset, block_size)
local output = iproc.crop(y,
pad_w1, pad_h1,
y:size(3) - pad_w2, y:size(2) - pad_h2)
collectgarbage()
@ -162,7 +178,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
local pad_w2 = (w - offset) - x:size(3)
local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
local y = reconstruct_rgb(model, input, offset, block_size)
local output = iproc.crop(y,
pad_w1, pad_h1,
y:size(3) - pad_w2, y:size(2) - pad_h2)
output[torch.lt(output, 0)] = 0
@ -172,18 +188,81 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
return output
end
function reconstruct.image(model, x, block_size)
if reconstruct.is_rgb(model) then
return reconstruct.image_rgb(model, x,
reconstruct.offset_size(model), block_size)
else
return reconstruct.image_y(model, x,
reconstruct.offset_size(model), block_size)
end
end
function reconstruct.scale(model, scale, x, block_size)
if reconstruct.is_rgb(model) then
return reconstruct.scale_rgb(model, scale, x,
reconstruct.offset_size(model), block_size)
else
return reconstruct.scale_y(model, scale, x,
reconstruct.offset_size(model), block_size)
end
end
local function tta(f, model, x, block_size)
local average = nil
local offset = reconstruct.offset_size(model)
for i = 1, 4 do
local flip_f, iflip_f
if i == 1 then
flip_f = function (a) return a end
iflip_f = function (a) return a end
elseif i == 2 then
flip_f = image.vflip
iflip_f = image.vflip
elseif i == 3 then
flip_f = image.hflip
iflip_f = image.hflip
elseif i == 4 then
flip_f = function (a) return image.hflip(image.vflip(a)) end
iflip_f = function (a) return image.vflip(image.hflip(a)) end
end
for j = 1, 2 do
local tr_f, itr_f
if j == 1 then
tr_f = function (a) return a end
itr_f = function (a) return a end
elseif j == 2 then
tr_f = function(a) return a:transpose(2, 3):contiguous() end
itr_f = function(a) return a:transpose(2, 3):contiguous() end
end
local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
offset, block_size)))
if not average then
average = out
else
average:add(out)
end
end
end
return average:div(8.0)
end
function reconstruct.image_tta(model, x, block_size)
if reconstruct.is_rgb(model) then
return tta(reconstruct.image_rgb, model, x, block_size)
else
return tta(reconstruct.image_y, model, x, block_size)
end
end
function reconstruct.scale_tta(model, scale, x, block_size)
if reconstruct.is_rgb(model) then
local f = function (model, x, offset, block_size)
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
end
return tta(f, model, x, block_size)
else
local f = function (model, x, offset, block_size)
return reconstruct.scale_y(model, scale, x, offset, block_size)
end
return tta(f, model, x, block_size)
end
end
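`tta` averages the network output over the eight symmetries of the square (four flip combinations, each with and without a transpose), so the TTA entry points cost eight forward passes per block; this backs the `-tta` option added to `waifu2x.lua` below. A usage sketch, assuming it runs from the repo root with `lib/` on `package.path` and a trained model in the ascii format used throughout:
```lua
-- Sketch only: compare the single-pass and TTA entry points.
require 'w2nn'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
local model = torch.load("models/anime_style_art_rgb/scale2.0x_model.t7", "ascii")
local x = image_loader.load_float("images/miku_small.png")
local y     = reconstruct.scale(model, 2.0, x)      -- one pass per block
local y_tta = reconstruct.scale_tta(model, 2.0, x)  -- eight passes, averaged
```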

View file

@ -1,5 +1,6 @@
require 'xlua'
require 'pl'
require 'trepl'
-- global settings
@ -14,22 +15,34 @@ local settings = {}
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x")
cmd:text("waifu2x-training")
cmd:text("Options:")
cmd:option("-seed", 11, 'fixed input seed')
cmd:option("-data_dir", "./data", 'data directory')
cmd:option("-test", "images/miku_small.png", 'test image file')
cmd:option("-gpu", -1, 'GPU Device ID')
cmd:option("-seed", 11, 'RNG seed')
cmd:option("-data_dir", "./data", 'path to data directory')
cmd:option("-backend", "cunn", '(cunn|cudnn)')
cmd:option("-test", "images/miku_small.png", 'path to test image')
cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
cmd:option("-method", "scale", 'method to training (noise|scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-style", "art", '(art|photo)')
cmd:option("-color", 'rgb', '(y|rgb)')
cmd:option("-scale", 2.0, 'scale')
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
cmd:option("-crop_size", 128, 'crop size')
cmd:option("-batch_size", 2, 'mini batch size')
cmd:option("-epoch", 200, 'epoch')
cmd:option("-core", 2, 'cpu core')
cmd:option("-crop_size", 46, 'crop size')
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
cmd:option("-batch_size", 8, 'mini batch size')
cmd:option("-epoch", 200, 'number of total epochs to run')
cmd:option("-thread", -1, 'number of CPU threads')
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
local opt = cmd:parse(arg)
for k, v in pairs(opt) do
@ -53,26 +66,16 @@ end
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
error("scale must be mod-2")
end
if not (settings.style == "art" or
settings.style == "photo") then
error(string.format("unknown style: %s", settings.style))
end
if settings.thread > 0 then
torch.setnumthreads(tonumber(settings.thread))
end
settings.images = string.format("%s/images.t7", settings.data_dir)
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
return settings

View file

@ -1,74 +1,68 @@
require 'w2nn'
function nn.SpatialConvolutionMM:reset(stdv)
stdv = math.sqrt(2 / ( self.kW * self.kH * self.nOutputPlane))
self.weight:normal(0, stdv)
self.bias:fill(0)
end
-- ref: http://arxiv.org/abs/1502.01852
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}
function srcnn.channels(model)
return model:get(model:size() - 1).weight:size(1)
end
function srcnn.waifu2x_cunn(ch)
local model = nn.Sequential()
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
function srcnn.waifu2x_cudnn(ch)
local model = nn.Sequential()
model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
model:add(w2nn.LeakyReLU(0.1))
model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
model:add(nn.View(-1):setNumInputDims(3))
--model:cuda()
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
return model
end
function srcnn.create(model_name, backend, color)
local ch = 3
if color == "rgb" then
ch = 3
elseif color == "y" then
ch = 1
else
error("unknown color: " .. color)
error("unsupported color: " + color)
end
if backend == "cunn" then
return srcnn.waifu2x_cunn(ch)
elseif backend == "cudnn" then
return srcnn.waifu2x_cudnn(ch)
else
error("unsupported backend: " + backend)
end
end
return srcnn

26
lib/w2nn.lua Normal file
View file

@ -0,0 +1,26 @@
local function load_nn()
require 'torch'
require 'nn'
end
local function load_cunn()
require 'cutorch'
require 'cunn'
end
local function load_cudnn()
require 'cudnn'
cudnn.benchmark = true
end
if w2nn then
return w2nn
else
pcall(load_cunn)
pcall(load_cudnn)
w2nn = {}
require 'LeakyReLU'
require 'LeakyReLU_deprecated'
require 'DepthExpand2x'
require 'WeightedMSECriterion'
require 'ClippedWeightedHuberCriterion'
require 'cleanup_model'
return w2nn
end
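`w2nn.lua` acts as a guarded loader: the `pcall` wrappers make the CUDA packages optional at require time, and the early `return w2nn` keeps the module a singleton. The same pattern in isolation, as an illustrative sketch:
```lua
-- Illustrative pattern only: require an optional package without aborting.
local ok, cudnn = pcall(require, 'cudnn')
if ok then
   cudnn.benchmark = true
else
   print("cudnn not available; continuing without it")
end
```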

Binary file not shown.

Binary file not shown.

169
tools/benchmark.lua Normal file
View file

@ -0,0 +1,169 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'xlua'
require 'w2nn'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
local gm = require 'graphicsmagick'
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x-benchmark")
cmd:text("Options:")
cmd:option("-dir", "./data/test", 'test image directory')
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
cmd:option("-model2_dir", "", 'model2 directory (optional)')
cmd:option("-method", "scale", '(scale|noise)')
cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
cmd:option("-color", "rgb", '(rgb|y)')
cmd:option("-noise_level", 1, 'model noise level')
cmd:option("-jpeg_quality", 75, 'jpeg quality')
cmd:option("-jpeg_times", 1, 'jpeg compression times')
cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
local opt = cmd:parse(arg)
torch.setdefaulttensortype('torch.FloatTensor')
if cudnn then
cudnn.fastest = true
cudnn.benchmark = false
end
local function MSE(x1, x2)
return (x1 - x2):pow(2):mean()
end
local function YMSE(x1, x2)
local x1_2 = image.rgb2y(x1)
local x2_2 = image.rgb2y(x2)
return (x1_2 - x2_2):pow(2):mean()
end
local function PSNR(x1, x2)
local mse = MSE(x1, x2)
return 10 * math.log10(1.0 / mse)
end
local function YPSNR(x1, x2)
local mse = YMSE(x1, x2)
return 10 * math.log10(1.0 / mse)
end
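Since both inputs are scaled to [0,1] before comparison, the peak signal is 1 and these reduce to PSNR = 10·log10(1/MSE); for instance, an MSE of 1e-3 corresponds to 30 dB. A worked check with hypothetical numbers:
```lua
-- Worked example (hypothetical MSE): 1e-3 -> 30 dB.
assert(math.abs(10 * math.log10(1.0 / 1e-3) - 30.0) < 1e-9)
```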
local function transform_jpeg(x, opt)
for i = 1, opt.jpeg_times do
jpeg = gm.Image(x, "RGB", "DHW")
jpeg:format("jpeg")
jpeg:samplingFactors({1.0, 1.0, 1.0})
local blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
jpeg:fromBlob(blob, len)
x = jpeg:toTensor("byte", "RGB", "DHW")
end
return x
end
local function transform_scale(x, opt)
return iproc.scale(x,
x:size(3) * 0.5,
x:size(2) * 0.5,
opt.filter)
end
local function benchmark(opt, x, input_func, model1, model2)
local model1_mse = 0
local model2_mse = 0
local model1_psnr = 0
local model2_psnr = 0
for i = 1, #x do
local ground_truth = x[i]
local input, model1_output, model2_output
input = input_func(ground_truth, opt)
input = input:float():div(255)
ground_truth = ground_truth:float():div(255)
t = sys.clock()
if input:size(3) == ground_truth:size(3) then
model1_output = reconstruct.image(model1, input)
if model2 then
model2_output = reconstruct.image(model2, input)
end
else
model1_output = reconstruct.scale(model1, 2.0, input)
if model2 then
model2_output = reconstruct.scale(model2, 2.0, input)
end
end
if opt.color == "y" then
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
if model2 then
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
end
elseif opt.color == "rgb" then
model1_mse = model1_mse + MSE(ground_truth, model1_output)
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
if model2 then
model2_mse = model2_mse + MSE(ground_truth, model2_output)
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
end
else
error("Unknown color: " .. opt.color)
end
if model2 then
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
i, #x,
model1_mse / i, model2_mse / i,
model1_psnr / i, model2_psnr / i
))
else
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
i, #x,
model1_mse / i, model1_psnr / i
))
end
io.stdout:flush()
end
io.stdout:write("\n")
end
local function load_data(test_dir)
local test_x = {}
local files = dir.getfiles(test_dir, "*.*")
for i = 1, #files do
table.insert(test_x, iproc.crop_mod4(image_loader.load_byte(files[i])))
xlua.progress(i, #files)
end
return test_x
end
function load_model(filename)
return torch.load(filename, "ascii")
end
print(opt)
if opt.method == "scale" then
local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
local s1, model1 = pcall(load_model, f1)
local s2, model2 = pcall(load_model, f2)
if not s1 then
error("Load error: " .. f1)
end
if not s2 then
model2 = nil
end
local test_x = load_data(opt.dir)
benchmark(opt, test_x, transform_scale, model1, model2)
elseif opt.method == "noise" then
local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level))
local s1, model1 = pcall(load_model, f1)
local s2, model2 = pcall(load_model, f2)
if not s1 then
error("Load error: " .. f1)
end
if not s2 then
model2 = nil
end
local test_x = load_data(opt.dir)
benchmark(opt, test_x, transform_jpeg, model1, model2)
end

25
tools/cleanup_model.lua Normal file
View file

@ -0,0 +1,25 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'w2nn'
torch.setdefaulttensortype("torch.FloatTensor")
local cmd = torch.CmdLine()
cmd:text()
cmd:text("cleanup model")
cmd:text("Options:")
cmd:option("-model", "./model.t7", 'path of model file')
cmd:option("-iformat", "binary", 'input format')
cmd:option("-oformat", "binary", 'output format')
local opt = cmd:parse(arg)
local model = torch.load(opt.model, opt.iformat)
if model then
w2nn.cleanup_model(model)
model:cuda()
model:evaluate()
torch.save(opt.model, model, opt.oformat)
else
error("model not found")
end

43
tools/cudnn2cunn.lua Normal file
View file

@ -0,0 +1,43 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'os'
require 'w2nn'
local srcnn = require 'srcnn'
local function cudnn2cunn(cudnn_model)
local cunn_model = srcnn.waifu2x_cunn(srcnn.channels(cudnn_model))
local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
assert(#weight_from == #weight_to)
for i = 1, #weight_from do
local from = weight_from[i]
local to = weight_to[i]
to.weight:copy(from.weight)
to.bias:copy(from.bias)
end
cunn_model:cuda()
cunn_model:evaluate()
return cunn_model
end
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x cudnn model to cunn model converter")
cmd:text("Options:")
cmd:option("-i", "", 'Specify the input cunn model')
cmd:option("-o", "", 'Specify the output cudnn model')
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
local opt = cmd:parse(arg)
if not path.isfile(opt.i) then
cmd:help()
os.exit(-1)
end
local cudnn_model = torch.load(opt.i, opt.iformat)
local cunn_model = cudnn2cunn(cudnn_model)
torch.save(opt.o, cunn_model, opt.oformat)
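Both convolution modules store the same weight and bias data up to layout, so the module-by-module copy should leave the network numerically unchanged. A hypothetical smoke test, assuming both models are still in scope after conversion and a CUDA device is available:
```lua
-- Hypothetical smoke test: converted and original models should agree
-- on a random input (output sizes match because the offsets are equal).
local x = torch.CudaTensor(1, 3, 32, 32):uniform()
local diff = (cudnn_model:forward(x) - cunn_model:forward(x)):abs():max()
assert(diff < 1e-4)
```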

43
tools/cunn2cudnn.lua Normal file
View file

@ -0,0 +1,43 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'os'
require 'w2nn'
local srcnn = require 'srcnn'
local function cunn2cudnn(cunn_model)
local cudnn_model = srcnn.waifu2x_cudnn(srcnn.channels(cunn_model))
local weight_from = cunn_model:findModules("nn.SpatialConvolutionMM")
local weight_to = cudnn_model:findModules("cudnn.SpatialConvolution")
assert(#weight_from == #weight_to)
for i = 1, #weight_from do
local from = weight_from[i]
local to = weight_to[i]
to.weight:copy(from.weight)
to.bias:copy(from.bias)
end
cudnn_model:cuda()
cudnn_model:evaluate()
return cudnn_model
end
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x cunn model to cudnn model converter")
cmd:text("Options:")
cmd:option("-i", "", 'Specify the input cudnn model')
cmd:option("-o", "", 'Specify the output cunn model')
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
local opt = cmd:parse(arg)
if not path.isfile(opt.i) then
cmd:help()
os.exit(-1)
end
local cunn_model = torch.load(opt.i, opt.iformat)
local cudnn_model = cunn2cudnn(cunn_model)
torch.save(opt.o, cudnn_model, opt.oformat)

54
tools/export_model.lua Normal file
View file

@ -0,0 +1,54 @@
-- adapted from https://github.com/marcan/cl-waifu2x
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'w2nn'
local cjson = require "cjson"
function export(model, output)
local jmodules = {}
local modules = model:findModules("nn.SpatialConvolutionMM")
if #modules == 0 then
-- cudnn model
modules = model:findModules("cudnn.SpatialConvolution")
end
for i = 1, #modules, 1 do
local module = modules[i]
local jmod = {
kW = module.kW,
kH = module.kH,
nInputPlane = module.nInputPlane,
nOutputPlane = module.nOutputPlane,
bias = torch.totable(module.bias:float()),
weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
}
table.insert(jmodules, jmod)
end
jmodules[1].color = "RGB"
jmodules[1].gamma = 0
jmodules[#jmodules].color = "RGB"
jmodules[#jmodules].gamma = 0
local fp = io.open(output, "w")
if not fp then
error("IO Error: " .. output)
end
fp:write(cjson.encode(jmodules))
fp:close()
end
local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x export model")
cmd:text("Options:")
cmd:option("-i", "input.t7", 'Specify the input torch model')
cmd:option("-o", "output.json", 'Specify the output json file')
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
local opt = cmd:parse(arg)
if not path.isfile(opt.i) then
cmd:help()
os.exit(-1)
end
local model = torch.load(opt.i, opt.iformat)
export(model, opt.o)
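The exported file is a flat JSON array with one object per convolution layer (`kW`, `kH`, `nInputPlane`, `nOutputPlane`, `bias`, `weight`), the format consumed by ports such as cl-waifu2x. A hypothetical round-trip check with the same `cjson` package:
```lua
-- Hypothetical check: decode the exported JSON and inspect the first layer.
local cjson = require 'cjson'
local fp = assert(io.open("output.json", "r"))
local layers = cjson.decode(fp:read("*a"))
fp:close()
print(layers[1].nInputPlane, layers[1].nOutputPlane, layers[1].kW, layers[1].kH)
```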

175
train.lua
View file

@ -1,21 +1,25 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
require 'optim'
require 'xlua'
require 'w2nn'
local settings = require 'settings'
local srcnn = require 'srcnn'
local minibatch_adam = require 'minibatch_adam'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local compression = require 'compression'
local pairwise_transform = require 'pairwise_transform'
local image_loader = require 'image_loader'
local function save_test_scale(model, rgb, file)
local up = reconstruct.scale(model, settings.scale, rgb)
image.save(file, up)
end
local function save_test_jpeg(model, rgb, file)
local im, count = reconstruct.image(model, rgb)
image.save(file, im)
end
local function split_data(x, test_size)
@ -31,14 +35,19 @@ local function split_data(x, test_size)
end
return train_x, valid_x
end
local function make_validation_set(x, transformer, n, batch_size)
n = n or 4
local data = {}
for i = 1, #x do
for k = 1, math.max(n / batch_size, 1) do
local xy = transformer(x[i], true, batch_size)
local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
for j = 1, #xy do
tx[j]:copy(xy[j][1])
ty[j]:copy(xy[j][2])
end
table.insert(data, {x = tx, y = ty})
end
xlua.progress(i, #x)
collectgarbage()
@ -50,24 +59,92 @@ local function validate(model, criterion, data)
for i = 1, #data do
local z = model:forward(data[i].x:cuda())
loss = loss + criterion:forward(z, data[i].y:cuda())
if i % 100 == 0 then
xlua.progress(i, #data)
collectgarbage()
end
end
xlua.progress(#data, #data)
return loss / #data
end
local function create_criterion(model)
if reconstruct.is_rgb(model) then
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(3, output_w * output_w)
weight[1]:fill(0.29891 * 3) -- R
weight[2]:fill(0.58661 * 3) -- G
weight[3]:fill(0.11448 * 3) -- B
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
else
return nn.MSECriterion():cuda()
end
end
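The per-channel weights above are the BT.601 luma coefficients (0.29891, 0.58661, 0.11448) scaled by 3, so green errors are penalized most, in proportion to their contribution to perceived brightness, while the total weight still equals the three uniformly weighted channels. A quick illustrative check:
```lua
-- Illustrative check: the scaled luma weights sum to the channel count.
local w = {0.29891 * 3, 0.58661 * 3, 0.11448 * 3}
assert(math.abs((w[1] + w[2] + w[3]) - 3.0) < 1e-9)
```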
local function transformer(x, is_validation, n, offset)
x = compression.decompress(x)
n = n or settings.batch_size;
if is_validation == nil then is_validation = false end
local random_color_noise_rate = nil
local random_overlay_rate = nil
local active_cropping_rate = nil
local active_cropping_tries = nil
if is_validation then
active_cropping_rate = 0
active_cropping_tries = 0
random_color_noise_rate = 0.0
random_overlay_rate = 0.0
else
active_cropping_rate = settings.active_cropping_rate
active_cropping_tries = settings.active_cropping_tries
random_color_noise_rate = settings.random_color_noise_rate
random_overlay_rate = settings.random_overlay_rate
end
if settings.method == "scale" then
return pairwise_transform.scale(x,
settings.scale,
settings.crop_size, offset,
n,
{
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb")
})
elseif settings.method == "noise" then
return pairwise_transform.jpeg(x,
settings.style,
settings.noise_level,
settings.crop_size, offset,
n,
{
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
max_size = settings.max_size,
jpeg_sampling_factors = settings.jpeg_sampling_factors,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
nr_rate = settings.nr_rate,
rgb = (settings.color == "rgb")
})
end
end
local function train()
local model = srcnn.create(settings.method, settings.backend, settings.color)
local offset = reconstruct.offset_size(model)
local pairwise_func = function(x, is_validation, n)
return transformer(x, is_validation, n, offset)
end
local criterion = create_criterion(model)
local x = torch.load(settings.images)
local lrd_count = 0
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
local adam_config = {
learningRate = settings.learning_rate,
xBatchSize = settings.batch_size,
@ -78,38 +155,11 @@ local function train()
elseif settings.color == "rgb" then
ch = 3
end
local best_score = 100000.0
print("# make validation-set")
local valid_xy = make_validation_set(valid_x, pairwise_func,
settings.validation_crops,
settings.batch_size)
valid_x = nil
collectgarbage()
@ -119,7 +169,7 @@ local function train()
model:training()
print("# " .. epoch)
print(minibatch_adam(model, criterion, train_x, adam_config,
pairwise_func,
{ch, settings.crop_size, settings.crop_size},
{ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
))
@ -127,6 +177,7 @@ local function train()
print("# validation")
local score = validate(model, criterion, valid_xy)
if score < best_score then
local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0
best_score = score
print("* update best model")
@ -134,22 +185,17 @@ local function train()
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.png"):format(settings.noise_level))
save_test_jpeg(model, test_image, log)
elseif settings.method == "scale" then
local log = path.join(settings.model_dir,
("scale%.1f_best.png"):format(settings.scale))
save_test_scale(model, test_image, log)
end
else
lrd_count = lrd_count + 1
if lrd_count > 5 then
lrd_count = 0
adam_config.learningRate = adam_config.learningRate * 0.9
print("* learning rate decay: " .. adam_config.learningRate)
end
end
@ -157,6 +203,9 @@ local function train()
collectgarbage()
end
end
if settings.gpu > 0 then
cutorch.setDevice(settings.gpu)
end
torch.manualSeed(settings.seed)
cutorch.manualSeed(settings.seed)
print(settings)

View file

@ -1,10 +1,12 @@
#!/bin/sh
th convert_data.lua
th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii

9
train_ukbench.sh Executable file
View file

@ -0,0 +1,9 @@
#!/bin/sh
th convert_data.lua -data_dir ./data/ukbench
#th train.lua -style photo -method noise -noise_level 2 -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.png -nr_rate 0.9 -jpeg_sampling_factors 420 # -thread 4 -backend cudnn
#th tools/cleanup_model.lua -model models/ukbench/noise2_model.t7 -oformat ascii
th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg # -thread 4 -backend cudnn
th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii

View file

@ -1,12 +1,11 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
require 'sys'
require 'w2nn'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
torch.setdefaulttensortype('torch.FloatTensor')
@ -14,43 +13,109 @@ local function convert_image(opt)
local x, alpha = image_loader.load_float(opt.i)
local new_x = nil
local t = sys.clock()
local scale_f, image_f
if opt.tta == 1 then
scale_f = reconstruct.scale_tta
image_f = reconstruct.image_tta
else
scale_f = reconstruct.scale
image_f = reconstruct.image
end
if opt.o == "(auto)" then
local name = path.basename(opt.i)
local e = path.extension(name)
local base = name:sub(0, name:len() - e:len())
opt.o = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
opt.o = path.join(path.dirname(opt.i), string.format("%s_%s.png", base, opt.m))
end
if opt.m == "noise" then
local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
model:evaluate()
new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size)
local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
local model = torch.load(model_path, "ascii")
if not model then
error("Load Error: " .. model_path)
end
new_x = image_f(model, x, opt.crop_size)
elseif opt.m == "scale" then
local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
model:evaluate()
new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
local model = torch.load(model_path, "ascii")
if not model then
error("Load Error: " .. model_path)
end
new_x = scale_f(model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" then
local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
local noise_model = torch.load(noise_model_path, "ascii")
local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
local scale_model = torch.load(scale_model_path, "ascii")
if not noise_model then
error("Load Error: " .. noise_model_path)
end
if not scale_model then
error("Load Error: " .. scale_model_path)
end
x = image_f(noise_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
if opt.white_noise == 1 then
new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
end
image_loader.save_png(opt.o, new_x, alpha, opt.depth)
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
end
local function convert_frames(opt)
local noise1_model = torch.load(path.join(opt.model_dir, "noise1_model.t7"), "ascii")
local noise2_model = torch.load(path.join(opt.model_dir, "noise2_model.t7"), "ascii")
local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
noise1_model:evaluate()
noise2_model:evaluate()
scale_model:evaluate()
local model_path, noise1_model, noise2_model, scale_model
local scale_f, image_f
if opt.tta == 1 then
scale_f = reconstruct.scale_tta
image_f = reconstruct.image_tta
else
scale_f = reconstruct.scale
image_f = reconstruct.image
end
if opt.m == "scale" then
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
scale_model = torch.load(model_path, "ascii")
if not scale_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise" and opt.noise_level == 1 then
model_path = path.join(opt.model_dir, "noise1_model.t7")
noise1_model = torch.load(model_path, "ascii")
if not noise1_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise" and opt.noise_level == 2 then
model_path = path.join(opt.model_dir, "noise2_model.t7")
noise2_model = torch.load(model_path, "ascii")
if not noise2_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise_scale" then
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
scale_model = torch.load(model_path, "ascii")
if not scale_model then
error("Load Error: " .. model_path)
end
if opt.noise_level == 1 then
model_path = path.join(opt.model_dir, "noise1_model.t7")
noise1_model = torch.load(model_path, "ascii")
if not noise1_model then
error("Load Error: " .. model_path)
end
elseif opt.noise_level == 2 then
model_path = path.join(opt.model_dir, "noise2_model.t7")
noise2_model = torch.load(model_path, "ascii")
if not noise2_model then
error("Load Error: " .. model_path)
end
end
end
local fp = io.open(opt.l)
if not fp then
error("Open Error: " .. opt.l)
end
local count = 0
local lines = {}
for line in fp:lines() do
@ -62,20 +127,24 @@ local function convert_frames(opt)
local x, alpha = image_loader.load_float(lines[i])
local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then
new_x = image_f(noise1_model, x, opt.crop_size)
elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = image_f(noise2_model, x, opt.crop_size)
elseif opt.m == "scale" then
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
x = image_f(noise1_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = image_f(noise2_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
else
error("undefined method:" .. opt.method)
end
if opt.white_noise == 1 then
new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
end
local output = nil
if opt.o == "(auto)" then
local name = path.basename(lines[i])
@ -85,7 +154,7 @@ local function convert_frames(opt)
else
output = string.format(opt.o, i)
end
image_loader.save_png(output, new_x, alpha, opt.depth)
xlua.progress(i, #lines)
if i % 10 == 0 then
collectgarbage()
@ -101,17 +170,30 @@ local function waifu2x()
cmd:text()
cmd:text("waifu2x")
cmd:text("Options:")
cmd:option("-i", "images/miku_small.png", 'path of the input image')
cmd:option("-l", "", 'path of the image-list')
cmd:option("-i", "images/miku_small.png", 'path to input image')
cmd:option("-l", "", 'path to image-list.txt')
cmd:option("-scale", 2, 'scale factor')
cmd:option("-o", "(auto)", 'path of the output file')
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
cmd:option("-o", "(auto)", 'path to output file')
cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-crop_size", 128, 'patch size per process')
cmd:option("-resume", 0, "skip existing files (0|1)")
cmd:option("-thread", -1, "number of CPU threads")
cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
cmd:option("-white_noise", 0, 'adding white noise to output image (0|1)')
cmd:option("-white_noise_std", 0.0055, 'standard division of white noise')
local opt = cmd:parse(arg)
if opt.thread > 0 then
torch.setnumthreads(opt.thread)
end
if cudnn then
cudnn.fastest = true
cudnn.benchmark = false
end
if string.len(opt.l) == 0 then
convert_image(opt)
else

209
web.lua
View file

@ -1,11 +1,21 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
local ROOT = path.dirname(__FILE__)
package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
_G.TURBO_SSL = true
require 'w2nn'
local uuid = require 'uuid'
local ffi = require 'ffi'
local md5 = require 'md5'
local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
-- Notes: turbo and xlua have different implementations of string:split().
--        Therefore, string:split() has a conflict issue.
--        In this script, turbo's string:split() is used.
local turbo = require 'turbo'
local cmd = torch.CmdLine()
cmd:text()
@ -13,24 +23,27 @@ cmd:text("waifu2x-api")
cmd:text("Options:")
cmd:option("-port", 8812, 'listen port')
cmd:option("-gpu", 1, 'Device ID')
cmd:option("-core", 2, 'number of CPU cores')
cmd:option("-thread", -1, 'number of CPU threads')
local opt = cmd:parse(arg)
cutorch.setDevice(opt.gpu)
torch.setdefaulttensortype('torch.FloatTensor')
if opt.thread > 0 then
torch.setnumthreads(opt.thread)
end
if cudnn then
cudnn.fastest = true
cudnn.benchmark = false
end
local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
--local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
--local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
--local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
local CLEANUP_MODEL = false -- if you are using a low-memory GPU, you can enable this flag.
local CACHE_DIR = path.join(ROOT, "cache")
local MAX_NOISE_IMAGE = 2560 * 2560
local MAX_SCALE_IMAGE = 1280 * 1280
local CURL_OPTIONS = {
@ -40,7 +53,6 @@ local CURL_OPTIONS = {
max_redirects = 2
}
local CURL_MAX_SIZE = 2 * 1024 * 1024
local function valid_size(x, scale)
if scale == 0 then
@ -50,20 +62,16 @@ local function valid_size(x, scale)
end
end
local function cache_url(url)
local hash = md5.sumhexa(url)
local cache_file = path.join(CACHE_DIR, "url_" .. hash)
if path.exists(cache_file) then
return image_loader.load_float(cache_file)
else
local res = coroutine.yield(
turbo.async.HTTPClient({verify_ca=false},
nil,
CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
)
if res.code == 200 then
local content_type = res.headers:get("Content-Type", true)
@ -71,33 +79,64 @@ local function get_image(req)
content_type = content_type[1]
end
if content_type and content_type:find("image") then
local fp = io.open(cache_file, "wb")
local blob = res.body
fp:write(blob)
fp:close()
return image_loader.decode_float(blob)
end
end
end
return nil, nil, nil
end
local function get_image(req)
local file = req:get_argument("file", "")
local url = req:get_argument("url", "")
if file and file:len() > 0 then
return image_loader.decode_float(file)
elseif url and url:len() > 0 then
return cache_url(url)
end
return nil, nil, nil
end
local function cleanup_model(model)
if CLEANUP_MODEL then
w2nn.cleanup_model(model) -- release GPU memory
end
end
local function convert(x, options)
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
if path.exists(cache_file) then
return image.load(cache_file)
else
if options.style == "art" then
if options.method == "scale" then
x = reconstruct.scale(art_scale2_model, 2.0, x)
cleanup_model(art_scale2_model)
elseif options.method == "noise1" then
x = reconstruct.image(art_noise1_model, x)
cleanup_model(art_noise1_model)
else -- options.method == "noise2"
x = reconstruct.image(art_noise2_model, x)
cleanup_model(art_noise2_model)
end
else --[[photo
if options.method == "scale" then
x = reconstruct.scale(photo_scale2_model, 2.0, x)
cleanup_model(photo_scale2_model)
elseif options.method == "noise1" then
x = reconstruct.image(photo_noise1_model, x)
cleanup_model(photo_noise1_model)
elseif options.method == "noise2" then
x = reconstruct.image(photo_noise2_model, x)
cleanup_model(photo_noise2_model)
end
--]]
end
image.save(cache_file, x)
return x
end
end
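`convert` keys its cache on a `<style>_<method>_<md5>` prefix (built in `APIHandler:post()` below from the md5 of the uploaded blob or fetched URL body), so a repeated request is served from `cache/` without touching a model. A minimal sketch of that keying, assuming `pl.path` and the `md5` package required at the top of this file:
```lua
-- Sketch only: reproduce the cache-path scheme used by convert().
local function cache_path(cache_dir, style, method, blob)
   return path.join(cache_dir, style .. "_" .. method .. "_" .. md5.sumhexa(blob) .. ".png")
end
-- e.g. cache_path(CACHE_DIR, "art", "noise1", blob)
```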
local function client_disconnected(handler)
return not(handler.request and
handler.request.connection and
@ -112,63 +151,51 @@ function APIHandler:post()
self:write("client disconnected")
return
end
local x, alpha, blob = get_image(self)
local scale = tonumber(self:get_argument("scale", "0"))
local noise = tonumber(self:get_argument("noise", "0"))
local white_noise = tonumber(self:get_argument("white_noise", "0"))
local style = self:get_argument("style", "art")
if style ~= "art" then
style = "photo" -- style must be art or photo
end
if x and valid_size(x, scale) then
if (noise ~= 0 or scale ~= 0) then
local hash = md5.sumhexa(blob)
if noise == 1 then
x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
elseif noise == 2 then
x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
end
if scale == 1 or scale == 2 then
if noise == 1 then
x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
elseif noise == 2 then
x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
else
x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
end
if scale == 1 then
x = iproc.scale_with_gamma22(x,
math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
"Jinc")
end
end
if white_noise == 1 then
x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0})
end
end
local name = uuid() .. ".png"
local blob = image_loader.encode_png(x, alpha)
self:set_header("Content-Disposition", string.format('filename="%s"', name))
self:set_header("Content-Type", "image/png")
self:set_header("Content-Length", string.format("%d", len))
self:write(ffi.string(blob, len))
self:set_header("Content-Length", string.format("%d", #blob))
self:write(blob)
else
if not x then
self:set_status(400)
self:write("ERROR: unsupported image format.")
self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
else
self:set_status(400)
self:write("ERROR: image size exceeds maximum allowable size.")
@ -177,9 +204,9 @@ function APIHandler:post()
collectgarbage()
end
local FormHandler = class("FormHandler", turbo.web.RequestHandler)
local index_ja = file.read("./assets/index.ja.html")
local index_ru = file.read("./assets/index.ru.html")
local index_en = file.read("./assets/index.html")
local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
local index_en = file.read(path.join(ROOT, "assets", "index.html"))
function FormHandler:get()
local lang = self.request.headers:get("Accept-Language")
if lang then
@ -209,9 +236,11 @@ turbo.log.categories = {
local app = turbo.web.Application:new(
{
{"^/$", FormHandler},
{"^/index.html", turbo.web.StaticFileHandler, path.join("./assets", "index.html")},
{"^/index.ja.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ja.html")},
{"^/index.ru.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ru.html")},
{"^/style.css", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "style.css")},
{"^/ui.js", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "ui.js")},
{"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
{"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
{"^/index.ru.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ru.html")},
{"^/api$", APIHandler},
}
)