2
.gitattributes
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
models/*/*.json binary
|
||||
*.t7 binary
|
13
.gitignore
vendored
|
@ -1,4 +1,15 @@
|
|||
*~
|
||||
work/
|
||||
cache/*.png
|
||||
models/*.png
|
||||
cache/url_*
|
||||
data/
|
||||
!data/.gitkeep
|
||||
|
||||
models/
|
||||
!models/anime_style_art
|
||||
!models/anime_style_art_rgb
|
||||
!models/ukbench
|
||||
models/*/*.png
|
||||
|
||||
waifu2x.log
|
||||
|
||||
|
|
54
README.md
|
@ -19,16 +19,11 @@ waifu2x is inspired by SRCNN [1]. 2D character picture (HatsuneMiku) is licensed
|
|||
|
||||
## Public AMI
|
||||
```
|
||||
AMI ID: ami-0be01e4f
|
||||
AMI NAME: waifu2x-server
|
||||
Instance Type: g2.2xlarge
|
||||
Region: US West (N.California)
|
||||
OS: Ubuntu 14.04
|
||||
User: ubuntu
|
||||
Created at: 2015-08-12
|
||||
TODO
|
||||
```
|
||||
|
||||
## Third Party Software
|
||||
|
||||
[Third-Party](https://github.com/nagadomi/waifu2x/wiki/Third-Party)
|
||||
|
||||
## Dependencies
|
||||
|
@ -37,10 +32,12 @@ Created at: 2015-08-12
|
|||
- NVIDIA GPU
|
||||
|
||||
### Platform
|
||||
|
||||
- [Torch7](http://torch.ch/)
|
||||
- [NVIDIA CUDA](https://developer.nvidia.com/cuda-toolkit)
|
||||
|
||||
### lualocks packages (excludes torch7's default packages)
|
||||
- lua-csnappy
|
||||
- md5
|
||||
- uuid
|
||||
- [turbo](https://github.com/kernelsauce/turbo)
|
||||
|
@ -57,34 +54,44 @@ See: [NVIDIA CUDA Getting Started Guide for Linux](http://docs.nvidia.com/cuda/c
|
|||
Download [CUDA](http://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
```
|
||||
sudo dpkg -i cuda-repo-ubuntu1404_7.0-28_amd64.deb
|
||||
sudo dpkg -i cuda-repo-ubuntu1404_7.5-18_amd64.deb
|
||||
sudo apt-get update
|
||||
sudo apt-get install cuda
|
||||
```
|
||||
|
||||
#### Install Package
|
||||
|
||||
```
|
||||
sudo apt-get install libsnappy-dev
|
||||
```
|
||||
|
||||
#### Install Torch7
|
||||
|
||||
See: [Getting started with Torch](http://torch.ch/docs/getting-started.html)
|
||||
|
||||
And install luarocks packages.
|
||||
```
|
||||
luarocks install graphicsmagick # upgrade
|
||||
luarocks install lua-csnappy
|
||||
luarocks install md5
|
||||
luarocks install uuid
|
||||
PREFIX=$HOME/torch/install luarocks install turbo # if you need to use web application
|
||||
```
|
||||
|
||||
#### Getting waifu2x
|
||||
|
||||
```
|
||||
git clone --depth 1 https://github.com/nagadomi/waifu2x.git
|
||||
```
|
||||
|
||||
#### Validation
|
||||
|
||||
Test the waifu2x command line tool.
|
||||
Testing the waifu2x command line tool.
|
||||
```
|
||||
th waifu2x.lua
|
||||
```
|
||||
|
||||
### Setting Up the Web Application Environment (if you needed)
|
||||
|
||||
#### Install packages
|
||||
|
||||
```
|
||||
luarocks install md5
|
||||
luarocks install uuid
|
||||
PREFIX=$HOME/torch/install luarocks install turbo
|
||||
```
|
||||
|
||||
## Web Application
|
||||
Run.
|
||||
```
|
||||
th web.lua
|
||||
```
|
||||
|
@ -114,11 +121,11 @@ th waifu2x.lua -m noise_scale -noise_level 1 -i input_image.png -o output_image.
|
|||
th waifu2x.lua -m noise_scale -noise_level 2 -i input_image.png -o output_image.png
|
||||
```
|
||||
|
||||
See also `images/gen.sh`.
|
||||
See also `th waifu2x.lua -h`.
|
||||
|
||||
### Video Encoding
|
||||
|
||||
\* `avconv` is `ffmpeg` on Ubuntu 14.04.
|
||||
\* `avconv` is alias of `ffmpeg` on Ubuntu 14.04.
|
||||
|
||||
Extracting images and audio from a video. (range: 00:09:00 ~ 00:12:00)
|
||||
```
|
||||
|
@ -144,6 +151,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 -
|
|||
```
|
||||
|
||||
## Training Your Own Model
|
||||
Notes: If you have cuDNN library, you can use cudnn kernel with `-backend cudnn` option. And you can convert trained cudnn model to cunn model with `tools/cudnn2cunn.lua`.
|
||||
|
||||
### Data Preparation
|
||||
|
||||
|
@ -151,7 +159,7 @@ Genrating a file list.
|
|||
```
|
||||
find /path/to/image/dir -name "*.png" > data/image_list.txt
|
||||
```
|
||||
(You should use PNG! In my case, waifu2x is trained with 3000 high-resolution-noise-free-PNG images.)
|
||||
You should use noise free images. In my case, waifu2x is trained with 6000 high-resolution-noise-free-PNG images.
|
||||
|
||||
Converting training data.
|
||||
```
|
||||
|
|
|
@ -3,7 +3,15 @@ require 'pl'
|
|||
CACHE_DIR="cache"
|
||||
TTL = 3600 * 24
|
||||
|
||||
local files = dir.getfiles(CACHE_DIR, "*.png")
|
||||
local files = {}
|
||||
local image_cache = dir.getfiles(CACHE_DIR, "*.png")
|
||||
local url_cache = dir.getfiles(CACHE_DIR, "url_*")
|
||||
for i = 1, #image_cache do
|
||||
table.insert(files, image_cache[i])
|
||||
end
|
||||
for i = 1, #url_cache do
|
||||
table.insert(files, url_cache[i])
|
||||
end
|
||||
local now = os.time()
|
||||
for i, f in pairs(files) do
|
||||
if now - path.getmtime(f) > TTL then
|
||||
|
|
|
@ -2,51 +2,17 @@
|
|||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<link rel="canonical" href="http://waifu2x.udp.jp/">
|
||||
<title>waifu2x</title>
|
||||
<style type="text/css">
|
||||
body {
|
||||
margin: 1em 2em 1em 2em;
|
||||
background: LightGray;
|
||||
width: 640px;
|
||||
}
|
||||
fieldset {
|
||||
margin-top: 1em;
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.about {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
font-size: 0.9em;
|
||||
padding: 1em 5px 0.2em 0;
|
||||
}
|
||||
.help {
|
||||
font-size: 0.85em;
|
||||
margin: 1em 0 0 0;
|
||||
}
|
||||
</style>
|
||||
<link href="style.css" rel="stylesheet" type="text/css">
|
||||
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
|
||||
<script type="text/javascript">
|
||||
function clear_file() {
|
||||
var new_file = $("#file").clone();
|
||||
new_file.change(clear_url);
|
||||
$("#file").replaceWith(new_file);
|
||||
}
|
||||
function clear_url() {
|
||||
$("#url").val("")
|
||||
}
|
||||
$(function (){
|
||||
$("#url").change(clear_file);
|
||||
$("#file").change(clear_url);
|
||||
})
|
||||
</script>
|
||||
<script type="text/javascript" src="ui.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<h1>waifu2x</h1>
|
||||
<div class="header">
|
||||
<div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
|
||||
<img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
|
||||
<a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
|
||||
<div class="github-banner">
|
||||
<img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
|
||||
<a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
|
||||
</div>
|
||||
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
|
||||
</div>
|
||||
|
@ -66,12 +32,14 @@
|
|||
Limits: Size: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<fieldset class="noise-field">
|
||||
<legend>Noise Reduction (expect JPEG Artifact)</legend>
|
||||
<label><input type="radio" name="noise" value="0"> None</label>
|
||||
<label><input type="radio" name="noise" value="1" checked="checked"> Medium</label>
|
||||
<label><input type="radio" name="noise" value="2"> High</label>
|
||||
<div class="help">When using 2x scaling, we never recommend to use high level of noise reduction, it almost always makes image worse, it makes sense for only some rare cases when image had really bad quality from the beginning.</div>
|
||||
<div class="help">
|
||||
When using 2x scaling, we never recommend to use high level of noise reduction, it almost always makes image worse, it makes sense for only some rare cases when image had really bad quality from the beginning.
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<legend>Upscaling</legend>
|
||||
|
@ -82,7 +50,7 @@
|
|||
<input type="submit"/>
|
||||
</form>
|
||||
<div class="help">
|
||||
<ul style="padding-left: 15px;">
|
||||
<ul class="padding-left">
|
||||
<li>If you are using Firefox, Please press the CTRL+S key to save image. "Save Image" option doesn't work.
|
||||
</ul>
|
||||
</div>
|
||||
|
|
|
@ -2,51 +2,17 @@
|
|||
<html lang="ja">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<link rel="canonical" href="http://waifu2x.udp.jp/">
|
||||
<link href="style.css" rel="stylesheet" type="text/css">
|
||||
<title>waifu2x</title>
|
||||
<style type="text/css">
|
||||
body {
|
||||
margin: 1em 2em 1em 2em;
|
||||
background: LightGray;
|
||||
width: 640px;
|
||||
}
|
||||
fieldset {
|
||||
margin-top: 1em;
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.about {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
font-size: 0.8em;
|
||||
padding: 1em 5px 0.2em 0;
|
||||
}
|
||||
.help {
|
||||
font-size: 0.8em;
|
||||
margin: 1em 0 0 0;
|
||||
}
|
||||
</style>
|
||||
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
|
||||
<script type="text/javascript">
|
||||
function clear_file() {
|
||||
var new_file = $("#file").clone();
|
||||
new_file.change(clear_url);
|
||||
$("#file").replaceWith(new_file);
|
||||
}
|
||||
function clear_url() {
|
||||
$("#url").val("")
|
||||
}
|
||||
$(function (){
|
||||
$("#url").change(clear_file);
|
||||
$("#file").change(clear_url);
|
||||
})
|
||||
</script>
|
||||
<script type="text/javascript" src="ui.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<h1>waifu2x</h1>
|
||||
<div class="header">
|
||||
<div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
|
||||
<img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
|
||||
<a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
|
||||
<div class="github-banner">
|
||||
<img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
|
||||
<a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
|
||||
</div>
|
||||
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
|
||||
</div>
|
||||
|
@ -66,7 +32,7 @@
|
|||
制限: サイズ: 2MB, ノイズ除去: 2560x2560px, 拡大: 1280x1280px
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<fieldset class="noise-field">
|
||||
<legend>ノイズ除去 (JPEGノイズを想定)</legend>
|
||||
<label><input type="radio" name="noise" value="0"> なし</label>
|
||||
<label><input type="radio" name="noise" value="1" checked="checked"> 弱</label>
|
||||
|
@ -81,7 +47,7 @@
|
|||
<input type="submit" value="実行"/>
|
||||
</form>
|
||||
<div class="help">
|
||||
<ul style="padding-left: 15px;">
|
||||
<ul class="padding-left">
|
||||
<li>なし/なしで入力画像を変換せずに出力する。ブラウザのタブで変換結果を比較したい人用。
|
||||
<li>Firefoxの方は、右クリから画像が保存できないようなので、CTRL+SキーかALTキー後 ファイル - ページを保存 で画像を保存してください。
|
||||
</ul>
|
||||
|
|
|
@ -2,51 +2,18 @@
|
|||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<link rel="canonical" href="http://waifu2x.udp.jp/">
|
||||
<link href="style.css" rel="stylesheet" type="text/css">
|
||||
<title>waifu2x</title>
|
||||
<style type="text/css">
|
||||
body {
|
||||
margin: 1em 2em 1em 2em;
|
||||
background: LightGray;
|
||||
width: 640px;
|
||||
}
|
||||
fieldset {
|
||||
margin-top: 1em;
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.about {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
font-size: 0.9em;
|
||||
padding: 1em 5px 0.2em 0;
|
||||
}
|
||||
.help {
|
||||
font-size: 0.85em;
|
||||
margin: 1em 0 0 0;
|
||||
}
|
||||
</style>
|
||||
<link href="style.css" rel="stylesheet" type="text/css">
|
||||
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
|
||||
<script type="text/javascript">
|
||||
function clear_file() {
|
||||
var new_file = $("#file").clone();
|
||||
new_file.change(clear_url);
|
||||
$("#file").replaceWith(new_file);
|
||||
}
|
||||
function clear_url() {
|
||||
$("#url").val("")
|
||||
}
|
||||
$(function (){
|
||||
$("#url").change(clear_file);
|
||||
$("#file").change(clear_url);
|
||||
})
|
||||
</script>
|
||||
<script type="text/javascript" src="ui.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<h1>waifu2x</h1>
|
||||
<div class="header">
|
||||
<div style="position:absolute; display:block; top:0; left:540px; max-height:140px;">
|
||||
<img style="position:absolute; display:block; left:0; top:0; width:149px; height:149px; border:0;" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
|
||||
<a href="https://github.com/nagadomi/waifu2x" target="_blank" style="position:absolute; display:block; left:0; top:0; width:149px; height:130px;"></a>
|
||||
<div class="github-banner">
|
||||
<img class="github-banner-image" src="https://camo.githubusercontent.com/a6677b08c955af8400f44c6298f40e7d19cc5b2d/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f677261795f3664366436642e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_gray_6d6d6d.png">
|
||||
<a class="github-banner-link" href="https://github.com/nagadomi/waifu2x" target="_blank"></a>
|
||||
</div>
|
||||
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
|
||||
</div>
|
||||
|
@ -66,11 +33,11 @@
|
|||
Макс. размер файла — 2MB, устранение шума — макс. 2560x2560px, апскейл — 1280x1280px
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<legend>Устранение шума (артефактов JPEG)</legend>
|
||||
<fieldset class="noise-field">
|
||||
<legend>Устранение шума (артефактов JPEG)</legend>
|
||||
<label><input type="radio" name="noise" value="0"> Нет</label>
|
||||
<label><input type="radio" name="noise" value="1" checked="checked"> Средне</label>
|
||||
<label><input type="radio" name="noise" value="2"> Сильно (не рекомендуется)</label>
|
||||
<label><input type="radio" name="noise" value="2"> Сильно</label>
|
||||
<div class="help">Устранение шума нужно использовать, если на картинке действительно есть шум, иначе это даст противоположный эффект. Также не рекомендуется сильное устранение шума, оно даёт выгоду только в редких случаях, когда картинка изначально была сильно испорчена.</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
|
@ -82,8 +49,9 @@
|
|||
<input type="submit"/>
|
||||
</form>
|
||||
<div class="help">
|
||||
<ul style="padding-left: 15px;">
|
||||
<ul class="padding-left">
|
||||
<li>Если Вы используете Firefox, для сохранения изображения Вам придётся нажать Ctrl+S (опция в меню "Сохранить изображение" работать не будет!)
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</body>
|
||||
|
|
52
assets/style.css
Normal file
|
@ -0,0 +1,52 @@
|
|||
body {
|
||||
margin: 1em 2em 1em 2em;
|
||||
background: LightGray;
|
||||
width: 640px;
|
||||
}
|
||||
fieldset {
|
||||
margin-top: 1em;
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
.about {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
font-size: 0.9em;
|
||||
padding: 1em 5px 0.2em 0;
|
||||
}
|
||||
.help {
|
||||
font-size: 0.8em;
|
||||
margin: 1em 0 0 0;
|
||||
}
|
||||
.github-banner {
|
||||
position:absolute;
|
||||
display:block;
|
||||
top:0;
|
||||
left:540px;
|
||||
max-height:140px;
|
||||
}
|
||||
.github-banner-image {
|
||||
position: absolute;
|
||||
display: block;
|
||||
left: 0;
|
||||
top: 0;
|
||||
width: 149px;
|
||||
height: 149px;
|
||||
border: 0;
|
||||
}
|
||||
.github-banner-link {
|
||||
position: absolute;
|
||||
display: block;
|
||||
left:0;
|
||||
top:0;
|
||||
width:149px;
|
||||
height:130px;
|
||||
}
|
||||
.padding-left {
|
||||
padding-left: 15px;
|
||||
}
|
||||
.hide {
|
||||
display: none;
|
||||
}
|
||||
.experimental {
|
||||
margin-bottom: 1em;
|
||||
}
|
80
assets/ui.js
Normal file
|
@ -0,0 +1,80 @@
|
|||
$(function (){
|
||||
function clear_file() {
|
||||
var new_file = $("#file").clone();
|
||||
new_file.change(clear_url);
|
||||
$("#file").replaceWith(new_file);
|
||||
}
|
||||
function clear_url() {
|
||||
$("#url").val("")
|
||||
}
|
||||
function on_change_style(e) {
|
||||
$("input[name=style]").parents("label").each(
|
||||
function (i, elm) {
|
||||
$(elm).css("font-weight", "normal");
|
||||
});
|
||||
var checked = $("input[name=style]:checked");
|
||||
checked.parents("label").css("font-weight", "bold");
|
||||
if (checked.val() == "art") {
|
||||
$("h1").text("waifu2x");
|
||||
} else {
|
||||
$("h1").html("w<s>/a/</s>ifu2x");
|
||||
}
|
||||
}
|
||||
function on_change_noise_level(e)
|
||||
{
|
||||
$("input[name=noise]").parents("label").each(
|
||||
function (i, elm) {
|
||||
$(elm).css("font-weight", "normal");
|
||||
});
|
||||
var checked = $("input[name=noise]:checked");
|
||||
if (checked.val() != 0) {
|
||||
checked.parents("label").css("font-weight", "bold");
|
||||
}
|
||||
}
|
||||
function on_change_scale_factor(e)
|
||||
{
|
||||
$("input[name=scale]").parents("label").each(
|
||||
function (i, elm) {
|
||||
$(elm).css("font-weight", "normal");
|
||||
});
|
||||
var checked = $("input[name=scale]:checked");
|
||||
if (checked.val() != 0) {
|
||||
checked.parents("label").css("font-weight", "bold");
|
||||
}
|
||||
}
|
||||
function on_change_white_noise(e)
|
||||
{
|
||||
$("input[name=white_noise]").parents("label").each(
|
||||
function (i, elm) {
|
||||
$(elm).css("font-weight", "normal");
|
||||
});
|
||||
var checked = $("input[name=white_noise]:checked");
|
||||
if (checked.val() != 0) {
|
||||
checked.parents("label").css("font-weight", "bold");
|
||||
}
|
||||
}
|
||||
function on_click_experimental_button(e)
|
||||
{
|
||||
if ($(this).hasClass("close")) {
|
||||
$(".experimental .container").show();
|
||||
$(this).removeClass("close");
|
||||
} else {
|
||||
$(".experimental .container").hide();
|
||||
$(this).addClass("close");
|
||||
}
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}
|
||||
|
||||
$("#url").change(clear_file);
|
||||
$("#file").change(clear_url);
|
||||
//$("input[name=style]").change(on_change_style);
|
||||
$("input[name=noise]").change(on_change_noise_level);
|
||||
$("input[name=scale]").change(on_change_scale_factor);
|
||||
//$("input[name=white_noise]").change(on_change_white_noise);
|
||||
//$(".experimental .button").click(on_click_experimental_button)
|
||||
|
||||
//on_change_style();
|
||||
on_change_scale_factor();
|
||||
on_change_noise_level();
|
||||
})
|
|
@ -1,48 +1,47 @@
|
|||
require './lib/portable'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||
|
||||
require 'pl'
|
||||
require 'image'
|
||||
local settings = require './lib/settings'
|
||||
local image_loader = require './lib/image_loader'
|
||||
|
||||
local function count_lines(file)
|
||||
local fp = io.open(file, "r")
|
||||
local count = 0
|
||||
for line in fp:lines() do
|
||||
count = count + 1
|
||||
end
|
||||
fp:close()
|
||||
|
||||
return count
|
||||
end
|
||||
|
||||
local function crop_4x(x)
|
||||
local w = x:size(3) % 4
|
||||
local h = x:size(2) % 4
|
||||
return image.crop(x, 0, 0, x:size(3) - w, x:size(2) - h)
|
||||
end
|
||||
local compression = require 'compression'
|
||||
local settings = require 'settings'
|
||||
local image_loader = require 'image_loader'
|
||||
local iproc = require 'iproc'
|
||||
|
||||
local function load_images(list)
|
||||
local count = count_lines(list)
|
||||
local fp = io.open(list, "r")
|
||||
local MARGIN = 32
|
||||
local lines = utils.split(file.read(list), "\n")
|
||||
local x = {}
|
||||
local c = 0
|
||||
for line in fp:lines() do
|
||||
local im = crop_4x(image_loader.load_byte(line))
|
||||
if im then
|
||||
if im:size(2) >= settings.crop_size * 2 and im:size(3) >= settings.crop_size * 2 then
|
||||
table.insert(x, im)
|
||||
end
|
||||
for i = 1, #lines do
|
||||
local line = lines[i]
|
||||
local im, alpha = image_loader.load_byte(line)
|
||||
if alpha then
|
||||
io.stderr:write(string.format("\n%s: skip: image has alpha channel.\n", line))
|
||||
else
|
||||
print("error:" .. line)
|
||||
im = iproc.crop_mod4(im)
|
||||
local scale = 1.0
|
||||
if settings.random_half then
|
||||
scale = 2.0
|
||||
end
|
||||
if im then
|
||||
if im:size(2) > (settings.crop_size * scale + MARGIN) and im:size(3) > (settings.crop_size * scale + MARGIN) then
|
||||
table.insert(x, compression.compress(im))
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", line, settings.crop_size * scale + MARGIN))
|
||||
end
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: load error.\n", line))
|
||||
end
|
||||
end
|
||||
c = c + 1
|
||||
xlua.progress(c, count)
|
||||
if c % 10 == 0 then
|
||||
xlua.progress(i, #lines)
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
return x
|
||||
end
|
||||
|
||||
torch.manualSeed(settings.seed)
|
||||
print(settings)
|
||||
local x = load_images(settings.image_list)
|
||||
torch.save(settings.images, x)
|
||||
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
require 'cunn'
|
||||
require 'cudnn'
|
||||
require 'cutorch'
|
||||
require './lib/LeakyReLU'
|
||||
local srcnn = require 'lib/srcnn'
|
||||
|
||||
local function cudnn2cunn(cudnn_model)
|
||||
local cunn_model = srcnn.waifu2x("y")
|
||||
local from_seq = cudnn_model:findModules("cudnn.SpatialConvolution")
|
||||
local to_seq = cunn_model:findModules("nn.SpatialConvolutionMM")
|
||||
|
||||
for i = 1, #from_seq do
|
||||
local from = from_seq[i]
|
||||
local to = to_seq[i]
|
||||
to.weight:copy(from.weight)
|
||||
to.bias:copy(from.bias)
|
||||
end
|
||||
cunn_model:cuda()
|
||||
cunn_model:evaluate()
|
||||
return cunn_model
|
||||
end
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("convert cudnn model to cunn model ")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-model", "./model.t7", 'path of cudnn model file')
|
||||
cmd:option("-iformat", "ascii", 'input format')
|
||||
cmd:option("-oformat", "ascii", 'output format')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
local cudnn_model = torch.load(opt.model, opt.iformat)
|
||||
local cunn_model = cudnn2cunn(cudnn_model)
|
||||
torch.save(opt.model, cunn_model, opt.oformat)
|
|
@ -1,23 +0,0 @@
|
|||
-- adapted from https://github.com/marcan/cl-waifu2x
|
||||
require './lib/portable'
|
||||
require './lib/LeakyReLU'
|
||||
local cjson = require "cjson"
|
||||
|
||||
local model = torch.load(arg[1], "ascii")
|
||||
|
||||
local jmodules = {}
|
||||
local modules = model:findModules("nn.SpatialConvolutionMM")
|
||||
for i = 1, #modules, 1 do
|
||||
local module = modules[i]
|
||||
local jmod = {
|
||||
kW = module.kW,
|
||||
kH = module.kH,
|
||||
nInputPlane = module.nInputPlane,
|
||||
nOutputPlane = module.nOutputPlane,
|
||||
bias = torch.totable(module.bias:float()),
|
||||
weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
|
||||
}
|
||||
table.insert(jmodules, jmod)
|
||||
end
|
||||
|
||||
io.write(cjson.encode(jmodules))
|
|
@ -1,8 +1,7 @@
|
|||
#!/bin/sh
|
||||
|
||||
th waifu2x.lua -noise_level 1 -m noise_scale -i images/miku_small.png -o images/miku_small_waifu2x.png
|
||||
th waifu2x.lua -m scale -i images/miku_small.png -o images/miku_small_waifu2x.png
|
||||
th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_small_noisy.png -o images/miku_small_noisy_waifu2x.png
|
||||
th waifu2x.lua -noise_level 2 -m noise -i images/miku_noisy.png -o images/miku_noisy_waifu2x.png
|
||||
th waifu2x.lua -noise_level 2 -m noise_scale -i images/miku_CC_BY-NC_noisy.jpg -o images/miku_CC_BY-NC_noisy_waifu2x.png
|
||||
th waifu2x.lua -noise_level 2 -m noise -i images/lena.png -o images/lena_waifu2x.png
|
||||
th waifu2x.lua -m scale -model_dir models/ukbench -i images/lena.png -o images/lena_waifu2x_ukbench.png
|
||||
|
|
Before Width: | Height: | Size: 315 KiB After Width: | Height: | Size: 397 KiB |
Before Width: | Height: | Size: 1.4 MiB After Width: | Height: | Size: 1.4 MiB |
Before Width: | Height: | Size: 605 KiB After Width: | Height: | Size: 651 KiB |
Before Width: | Height: | Size: 154 KiB After Width: | Height: | Size: 156 KiB |
Before Width: | Height: | Size: 53 KiB After Width: | Height: | Size: 45 KiB |
Before Width: | Height: | Size: 177 KiB After Width: | Height: | Size: 156 KiB |
Before Width: | Height: | Size: 138 KiB After Width: | Height: | Size: 154 KiB |
Before Width: | Height: | Size: 136 KiB After Width: | Height: | Size: 162 KiB |
BIN
images/slide.odp
BIN
images/slide.png
Before Width: | Height: | Size: 1.2 MiB After Width: | Height: | Size: 1.2 MiB |
Before Width: | Height: | Size: 499 KiB After Width: | Height: | Size: 493 KiB |
Before Width: | Height: | Size: 380 KiB After Width: | Height: | Size: 368 KiB |
Before Width: | Height: | Size: 378 KiB After Width: | Height: | Size: 352 KiB |
39
lib/ClippedWeightedHuberCriterion.lua
Normal file
|
@ -0,0 +1,39 @@
|
|||
-- ref: https://en.wikipedia.org/wiki/Huber_loss
|
||||
local ClippedWeightedHuberCriterion, parent = torch.class('w2nn.ClippedWeightedHuberCriterion','nn.Criterion')
|
||||
|
||||
function ClippedWeightedHuberCriterion:__init(w, gamma, clip)
|
||||
parent.__init(self)
|
||||
self.clip = clip
|
||||
self.gamma = gamma or 1.0
|
||||
self.weight = w:clone()
|
||||
self.diff = torch.Tensor()
|
||||
self.diff_abs = torch.Tensor()
|
||||
--self.outlier_rate = 0.0
|
||||
self.square_loss_buff = torch.Tensor()
|
||||
self.linear_loss_buff = torch.Tensor()
|
||||
end
|
||||
function ClippedWeightedHuberCriterion:updateOutput(input, target)
|
||||
self.diff:resizeAs(input):copy(input)
|
||||
self.diff[torch.lt(self.diff, self.clip[1])] = self.clip[1]
|
||||
self.diff[torch.gt(self.diff, self.clip[2])] = self.clip[2]
|
||||
for i = 1, input:size(1) do
|
||||
self.diff[i]:add(-1, target[i]):cmul(self.weight)
|
||||
end
|
||||
self.diff_abs:resizeAs(self.diff):copy(self.diff):abs()
|
||||
|
||||
local square_targets = self.diff[torch.lt(self.diff_abs, self.gamma)]
|
||||
local linear_targets = self.diff[torch.ge(self.diff_abs, self.gamma)]
|
||||
local square_loss = self.square_loss_buff:resizeAs(square_targets):copy(square_targets):pow(2.0):mul(0.5):sum()
|
||||
local linear_loss = self.linear_loss_buff:resizeAs(linear_targets):copy(linear_targets):abs():add(-0.5 * self.gamma):mul(self.gamma):sum()
|
||||
|
||||
--self.outlier_rate = linear_targets:nElement() / input:nElement()
|
||||
self.output = (square_loss + linear_loss) / input:nElement()
|
||||
return self.output
|
||||
end
|
||||
function ClippedWeightedHuberCriterion:updateGradInput(input, target)
|
||||
local norm = 1.0 / input:nElement()
|
||||
self.gradInput:resizeAs(self.diff):copy(self.diff):mul(norm)
|
||||
local outlier = torch.ge(self.diff_abs, self.gamma)
|
||||
self.gradInput[outlier] = torch.sign(self.diff[outlier]) * self.gamma * norm
|
||||
return self.gradInput
|
||||
end
|
77
lib/DepthExpand2x.lua
Normal file
|
@ -0,0 +1,77 @@
|
|||
if w2nn.DepthExpand2x then
|
||||
return w2nn.DepthExpand2x
|
||||
end
|
||||
local DepthExpand2x, parent = torch.class('w2nn.DepthExpand2x','nn.Module')
|
||||
|
||||
function DepthExpand2x:__init()
|
||||
parent:__init()
|
||||
end
|
||||
|
||||
function DepthExpand2x:updateOutput(input)
|
||||
local x = input
|
||||
-- (batch_size, depth, height, width)
|
||||
self.shape = x:size()
|
||||
|
||||
assert(self.shape:size() == 4, "input must be 4d tensor")
|
||||
assert(self.shape[2] % 4 == 0, "depth must be depth % 4 = 0")
|
||||
-- (batch_size, width, height, depth)
|
||||
x = x:transpose(2, 4)
|
||||
-- (batch_size, width, height * 2, depth / 2)
|
||||
x = x:reshape(self.shape[1], self.shape[4], self.shape[3] * 2, self.shape[2] / 2)
|
||||
-- (batch_size, height * 2, width, depth / 2)
|
||||
x = x:transpose(2, 3)
|
||||
-- (batch_size, height * 2, width * 2, depth / 4)
|
||||
x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4] * 2, self.shape[2] / 4)
|
||||
-- (batch_size, depth / 4, height * 2, width * 2)
|
||||
x = x:transpose(2, 4)
|
||||
x = x:transpose(3, 4)
|
||||
self.output:resizeAs(x):copy(x) -- contiguous
|
||||
|
||||
return self.output
|
||||
end
|
||||
|
||||
function DepthExpand2x:updateGradInput(input, gradOutput)
|
||||
-- (batch_size, depth / 4, height * 2, width * 2)
|
||||
local x = gradOutput
|
||||
-- (batch_size, height * 2, width * 2, depth / 4)
|
||||
x = x:transpose(2, 4)
|
||||
x = x:transpose(2, 3)
|
||||
-- (batch_size, height * 2, width, depth / 2)
|
||||
x = x:reshape(self.shape[1], self.shape[3] * 2, self.shape[4], self.shape[2] / 2)
|
||||
-- (batch_size, width, height * 2, depth / 2)
|
||||
x = x:transpose(2, 3)
|
||||
-- (batch_size, width, height, depth)
|
||||
x = x:reshape(self.shape[1], self.shape[4], self.shape[3], self.shape[2])
|
||||
-- (batch_size, depth, height, width)
|
||||
x = x:transpose(2, 4)
|
||||
|
||||
self.gradInput:resizeAs(x):copy(x)
|
||||
|
||||
return self.gradInput
|
||||
end
|
||||
|
||||
function DepthExpand2x.test()
|
||||
require 'image'
|
||||
local function show(x)
|
||||
local img = torch.Tensor(3, x:size(3), x:size(4))
|
||||
img[1]:copy(x[1][1])
|
||||
img[2]:copy(x[1][2])
|
||||
img[3]:copy(x[1][3])
|
||||
image.display(img)
|
||||
end
|
||||
local img = image.lena()
|
||||
local x = torch.Tensor(1, img:size(1) * 4, img:size(2), img:size(3))
|
||||
for i = 0, img:size(1) * 4 - 1 do
|
||||
src_index = ((i % 3) + 1)
|
||||
x[1][i + 1]:copy(img[src_index])
|
||||
end
|
||||
show(x)
|
||||
|
||||
local de2x = w2nn.DepthExpand2x()
|
||||
out = de2x:forward(x)
|
||||
show(out)
|
||||
out = de2x:updateGradInput(x, out)
|
||||
show(out)
|
||||
end
|
||||
|
||||
return DepthExpand2x
|
|
@ -1,7 +1,8 @@
|
|||
if nn.LeakyReLU then
|
||||
return
|
||||
if w2nn and w2nn.LeakyReLU then
|
||||
return w2nn.LeakyReLU
|
||||
end
|
||||
local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
|
||||
|
||||
local LeakyReLU, parent = torch.class('w2nn.LeakyReLU','nn.Module')
|
||||
|
||||
function LeakyReLU:__init(negative_scale)
|
||||
parent.__init(self)
|
||||
|
|
31
lib/LeakyReLU_deprecated.lua
Normal file
|
@ -0,0 +1,31 @@
|
|||
if nn.LeakyReLU then
|
||||
return nn.LeakyReLU
|
||||
end
|
||||
|
||||
local LeakyReLU, parent = torch.class('nn.LeakyReLU','nn.Module')
|
||||
|
||||
function LeakyReLU:__init(negative_scale)
|
||||
parent.__init(self)
|
||||
self.negative_scale = negative_scale or 0.333
|
||||
self.negative = torch.Tensor()
|
||||
end
|
||||
|
||||
function LeakyReLU:updateOutput(input)
|
||||
self.output:resizeAs(input):copy(input):abs():add(input):div(2)
|
||||
self.negative:resizeAs(input):copy(input):abs():add(-1.0, input):mul(-0.5*self.negative_scale)
|
||||
self.output:add(self.negative)
|
||||
|
||||
return self.output
|
||||
end
|
||||
|
||||
function LeakyReLU:updateGradInput(input, gradOutput)
|
||||
self.gradInput:resizeAs(gradOutput)
|
||||
-- filter positive
|
||||
self.negative:sign():add(1)
|
||||
torch.cmul(self.gradInput, gradOutput, self.negative)
|
||||
-- filter negative
|
||||
self.negative:add(-1):mul(-1 * self.negative_scale):cmul(gradOutput)
|
||||
self.gradInput:add(self.negative)
|
||||
|
||||
return self.gradInput
|
||||
end
|
25
lib/WeightedMSECriterion.lua
Normal file
|
@ -0,0 +1,25 @@
|
|||
local WeightedMSECriterion, parent = torch.class('w2nn.WeightedMSECriterion','nn.Criterion')
|
||||
|
||||
function WeightedMSECriterion:__init(w)
|
||||
parent.__init(self)
|
||||
self.weight = w:clone()
|
||||
self.diff = torch.Tensor()
|
||||
self.loss = torch.Tensor()
|
||||
end
|
||||
|
||||
function WeightedMSECriterion:updateOutput(input, target)
|
||||
self.diff:resizeAs(input):copy(input)
|
||||
for i = 1, input:size(1) do
|
||||
self.diff[i]:add(-1, target[i]):cmul(self.weight)
|
||||
end
|
||||
self.loss:resizeAs(self.diff):copy(self.diff):cmul(self.diff)
|
||||
self.output = self.loss:mean()
|
||||
|
||||
return self.output
|
||||
end
|
||||
|
||||
function WeightedMSECriterion:updateGradInput(input, target)
|
||||
local norm = 2.0 / input:nElement()
|
||||
self.gradInput:resizeAs(input):copy(self.diff):mul(norm)
|
||||
return self.gradInput
|
||||
end
|
|
@ -1,9 +1,5 @@
|
|||
require './lib/portable'
|
||||
require './lib/LeakyReLU'
|
||||
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
|
||||
-- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
|
||||
|
||||
local function zeroDataSize(data)
|
||||
if type(data) == 'table' then
|
||||
for i = 1, #data do
|
||||
|
@ -14,7 +10,6 @@ local function zeroDataSize(data)
|
|||
end
|
||||
return data
|
||||
end
|
||||
|
||||
-- Resize the output, gradInput, etc temporary tensors to zero (so that the
|
||||
-- on disk size is smaller)
|
||||
local function cleanupModel(node)
|
||||
|
@ -27,7 +22,7 @@ local function cleanupModel(node)
|
|||
if node.finput ~= nil then
|
||||
node.finput = zeroDataSize(node.finput)
|
||||
end
|
||||
if tostring(node) == "nn.LeakyReLU" then
|
||||
if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
|
||||
if node.negative ~= nil then
|
||||
node.negative = zeroDataSize(node.negative)
|
||||
end
|
||||
|
@ -46,23 +41,8 @@ local function cleanupModel(node)
|
|||
end
|
||||
end
|
||||
end
|
||||
|
||||
collectgarbage()
|
||||
end
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("cleanup model")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-model", "./model.t7", 'path of model file')
|
||||
cmd:option("-iformat", "binary", 'input format')
|
||||
cmd:option("-oformat", "binary", 'output format')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
local model = torch.load(opt.model, opt.iformat)
|
||||
if model then
|
||||
function w2nn.cleanup_model(model)
|
||||
cleanupModel(model)
|
||||
torch.save(opt.model, model, opt.oformat)
|
||||
else
|
||||
error("model not found")
|
||||
return model
|
||||
end
|
17
lib/compression.lua
Normal file
|
@ -0,0 +1,17 @@
|
|||
-- snapply compression for ByteTensor
|
||||
require 'snappy'
|
||||
|
||||
local compression = {}
|
||||
compression.compress = function (bt)
|
||||
local enc = snappy.compress(bt:storage():string())
|
||||
return {bt:size(), torch.ByteStorage():string(enc)}
|
||||
end
|
||||
compression.decompress = function(data)
|
||||
local size = data[1]
|
||||
local dec = snappy.decompress(data[2]:string())
|
||||
local bt = torch.ByteTensor(unpack(torch.totable(size)))
|
||||
bt:storage():string(dec)
|
||||
return bt
|
||||
end
|
||||
|
||||
return compression
|
104
lib/data_augmentation.lua
Normal file
|
@ -0,0 +1,104 @@
|
|||
require 'image'
|
||||
local iproc = require 'iproc'
|
||||
|
||||
local data_augmentation = {}
|
||||
|
||||
local function pcacov(x)
|
||||
local mean = torch.mean(x, 1)
|
||||
local xm = x - torch.ger(torch.ones(x:size(1)), mean:squeeze())
|
||||
local c = torch.mm(xm:t(), xm)
|
||||
c:div(x:size(1) - 1)
|
||||
local ce, cv = torch.symeig(c, 'V')
|
||||
return ce, cv
|
||||
end
|
||||
function data_augmentation.color_noise(src, p, factor)
|
||||
factor = factor or 0.1
|
||||
if torch.uniform() < p then
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
|
||||
local ce, cv = pcacov(src_t)
|
||||
local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor)
|
||||
|
||||
pca_space = torch.mm(src_t, cv):t():contiguous()
|
||||
for i = 1, 3 do
|
||||
pca_space[i]:mul(color_scale[i])
|
||||
end
|
||||
local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
|
||||
dest[torch.lt(dest, 0.0)] = 0.0
|
||||
dest[torch.gt(dest, 1.0)] = 1.0
|
||||
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
return dest
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
function data_augmentation.overlay(src, p)
|
||||
if torch.uniform() < p then
|
||||
local r = torch.uniform()
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
src = src:contiguous()
|
||||
local flip = data_augmentation.flip(src)
|
||||
flip:mul(r):add(src * (1.0 - r))
|
||||
if conversion then
|
||||
flip = iproc.float2byte(flip)
|
||||
end
|
||||
return flip
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
function data_augmentation.shift_1px(src)
|
||||
-- reducing the even/odd issue in nearest neighbor scaler.
|
||||
local direction = torch.random(1, 4)
|
||||
local x_shift = 0
|
||||
local y_shift = 0
|
||||
if direction == 1 then
|
||||
x_shift = 1
|
||||
y_shift = 0
|
||||
elseif direction == 2 then
|
||||
x_shift = 0
|
||||
y_shift = 1
|
||||
elseif direction == 3 then
|
||||
x_shift = 1
|
||||
y_shift = 1
|
||||
elseif flip == 4 then
|
||||
x_shift = 0
|
||||
y_shift = 0
|
||||
end
|
||||
local w = src:size(3) - x_shift
|
||||
local h = src:size(2) - y_shift
|
||||
w = w - (w % 4)
|
||||
h = h - (h % 4)
|
||||
local dest = iproc.crop(src, x_shift, y_shift, x_shift + w, y_shift + h)
|
||||
return dest
|
||||
end
|
||||
function data_augmentation.flip(src)
|
||||
local flip = torch.random(1, 4)
|
||||
local tr = torch.random(1, 2)
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
local dest
|
||||
|
||||
src = src:contiguous()
|
||||
if tr == 1 then
|
||||
-- pass
|
||||
elseif tr == 2 then
|
||||
src = src:transpose(2, 3):contiguous()
|
||||
end
|
||||
if flip == 1 then
|
||||
dest = image.hflip(src)
|
||||
elseif flip == 2 then
|
||||
dest = image.vflip(src)
|
||||
elseif flip == 3 then
|
||||
dest = image.hflip(image.vflip(src))
|
||||
elseif flip == 4 then
|
||||
dest = src
|
||||
end
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
return dest
|
||||
end
|
||||
return data_augmentation
|
|
@ -1,74 +1,118 @@
|
|||
local gm = require 'graphicsmagick'
|
||||
local ffi = require 'ffi'
|
||||
local iproc = require 'iproc'
|
||||
require 'pl'
|
||||
|
||||
local image_loader = {}
|
||||
|
||||
function image_loader.decode_float(blob)
|
||||
local im, alpha = image_loader.decode_byte(blob)
|
||||
if im then
|
||||
im = im:float():div(255)
|
||||
end
|
||||
return im, alpha
|
||||
end
|
||||
function image_loader.encode_png(rgb, alpha)
|
||||
if rgb:type() == "torch.ByteTensor" then
|
||||
error("expect FloatTensor")
|
||||
end
|
||||
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
||||
local clip_eps16 = (1.0 / 65535.0) * 0.5 - (1.0e-7 * (1.0 / 65535.0) * 0.5)
|
||||
local background_color = 0.5
|
||||
|
||||
function image_loader.encode_png(rgb, alpha, depth)
|
||||
depth = depth or 8
|
||||
rgb = iproc.byte2float(rgb)
|
||||
if alpha then
|
||||
if not (alpha:size(2) == rgb:size(2) and alpha:size(3) == rgb:size(3)) then
|
||||
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "Sinc"):toTensor("float", "I", "DHW")
|
||||
alpha = gm.Image(alpha, "I", "DHW"):size(rgb:size(3), rgb:size(2), "SincFast"):toTensor("float", "I", "DHW")
|
||||
end
|
||||
local rgba = torch.Tensor(4, rgb:size(2), rgb:size(3))
|
||||
rgba[1]:copy(rgb[1])
|
||||
rgba[2]:copy(rgb[2])
|
||||
rgba[3]:copy(rgb[3])
|
||||
rgba[4]:copy(alpha)
|
||||
|
||||
if depth < 16 then
|
||||
rgba:add(clip_eps8)
|
||||
rgba[torch.lt(rgba, 0.0)] = 0.0
|
||||
rgba[torch.gt(rgba, 1.0)] = 1.0
|
||||
else
|
||||
rgba:add(clip_eps16)
|
||||
rgba[torch.lt(rgba, 0.0)] = 0.0
|
||||
rgba[torch.gt(rgba, 1.0)] = 1.0
|
||||
end
|
||||
local im = gm.Image():fromTensor(rgba, "RGBA", "DHW")
|
||||
im:format("png")
|
||||
return im:toBlob(9)
|
||||
return im:depth(depth):format("PNG"):toString(9)
|
||||
else
|
||||
if depth < 16 then
|
||||
rgb = rgb:clone():add(clip_eps8)
|
||||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
else
|
||||
rgb = rgb:clone():add(clip_eps16)
|
||||
rgb[torch.lt(rgb, 0.0)] = 0.0
|
||||
rgb[torch.gt(rgb, 1.0)] = 1.0
|
||||
end
|
||||
local im = gm.Image(rgb, "RGB", "DHW")
|
||||
im:format("png")
|
||||
return im:toBlob(9)
|
||||
return im:depth(depth):format("PNG"):toString(9)
|
||||
end
|
||||
end
|
||||
function image_loader.save_png(filename, rgb, alpha)
|
||||
local blob, len = image_loader.encode_png(rgb, alpha)
|
||||
function image_loader.save_png(filename, rgb, alpha, depth)
|
||||
depth = depth or 8
|
||||
local blob = image_loader.encode_png(rgb, alpha, depth)
|
||||
local fp = io.open(filename, "wb")
|
||||
fp:write(ffi.string(blob, len))
|
||||
if not fp then
|
||||
error("IO error: " .. filename)
|
||||
end
|
||||
fp:write(blob)
|
||||
fp:close()
|
||||
return true
|
||||
end
|
||||
function image_loader.decode_byte(blob)
|
||||
function image_loader.decode_float(blob)
|
||||
local load_image = function()
|
||||
local im = gm.Image()
|
||||
local alpha = nil
|
||||
local gamma_lcd = 0.454545
|
||||
|
||||
im:fromBlob(blob, #blob)
|
||||
|
||||
if im:colorspace() == "CMYK" then
|
||||
im:colorspace("RGB")
|
||||
end
|
||||
local gamma = math.floor(im:gamma() * 1000000) / 1000000
|
||||
if gamma ~= 0 and gamma ~= gamma_lcd then
|
||||
im:gammaCorrection(gamma / gamma_lcd)
|
||||
end
|
||||
-- FIXME: How to detect that a image has an alpha channel?
|
||||
if blob:sub(1, 4) == "\x89PNG" or blob:sub(1, 3) == "GIF" then
|
||||
-- split alpha channel
|
||||
im = im:toTensor('float', 'RGBA', 'DHW')
|
||||
local sum_alpha = (im[4] - 1):sum()
|
||||
if sum_alpha > 0 or sum_alpha < 0 then
|
||||
local sum_alpha = (im[4] - 1.0):sum()
|
||||
if sum_alpha < 0 then
|
||||
alpha = im[4]:reshape(1, im:size(2), im:size(3))
|
||||
-- drop full transparent background
|
||||
local mask = torch.le(alpha, 0.0)
|
||||
im[1][mask] = background_color
|
||||
im[2][mask] = background_color
|
||||
im[3][mask] = background_color
|
||||
end
|
||||
local new_im = torch.FloatTensor(3, im:size(2), im:size(3))
|
||||
new_im[1]:copy(im[1])
|
||||
new_im[2]:copy(im[2])
|
||||
new_im[3]:copy(im[3])
|
||||
im = new_im:mul(255):byte()
|
||||
im = new_im
|
||||
else
|
||||
im = im:toTensor('byte', 'RGB', 'DHW')
|
||||
im = im:toTensor('float', 'RGB', 'DHW')
|
||||
end
|
||||
return {im, alpha}
|
||||
return {im, alpha, blob}
|
||||
end
|
||||
local state, ret = pcall(load_image)
|
||||
if state then
|
||||
return ret[1], ret[2]
|
||||
return ret[1], ret[2], ret[3]
|
||||
else
|
||||
return nil
|
||||
return nil, nil, nil
|
||||
end
|
||||
end
|
||||
function image_loader.decode_byte(blob)
|
||||
local im, alpha
|
||||
im, alpha, blob = image_loader.decode_float(blob)
|
||||
|
||||
if im then
|
||||
im = iproc.float2byte(im)
|
||||
-- hmm, alpha does not convert here
|
||||
return im, alpha, blob
|
||||
else
|
||||
return nil, nil, nil
|
||||
end
|
||||
end
|
||||
function image_loader.load_float(file)
|
||||
|
@ -90,18 +134,16 @@ function image_loader.load_byte(file)
|
|||
return image_loader.decode_byte(buff)
|
||||
end
|
||||
local function test()
|
||||
require 'image'
|
||||
local img
|
||||
img = image_loader.load_float("./a.jpg")
|
||||
if img then
|
||||
print(img:min())
|
||||
print(img:max())
|
||||
image.display(img)
|
||||
end
|
||||
img = image_loader.load_float("./b.png")
|
||||
if img then
|
||||
image.display(img)
|
||||
end
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
local a = image_loader.load_float("../images/lena.png")
|
||||
local blob = image_loader.encode_png(a)
|
||||
local b = image_loader.decode_float(blob)
|
||||
assert((b - a):abs():sum() == 0)
|
||||
|
||||
a = image_loader.load_byte("../images/lena.png")
|
||||
blob = image_loader.encode_png(a)
|
||||
b = image_loader.decode_byte(blob)
|
||||
assert((b:float() - a:float()):abs():sum() == 0)
|
||||
end
|
||||
--test()
|
||||
return image_loader
|
||||
|
|
120
lib/iproc.lua
|
@ -1,16 +1,78 @@
|
|||
local gm = require 'graphicsmagick'
|
||||
local image = require 'image'
|
||||
local iproc = {}
|
||||
|
||||
function iproc.scale(src, width, height, filter)
|
||||
local t = "float"
|
||||
if src:type() == "torch.ByteTensor" then
|
||||
t = "byte"
|
||||
local iproc = {}
|
||||
local clip_eps8 = (1.0 / 255.0) * 0.5 - (1.0e-7 * (1.0 / 255.0) * 0.5)
|
||||
|
||||
function iproc.crop_mod4(src)
|
||||
local w = src:size(3) % 4
|
||||
local h = src:size(2) % 4
|
||||
return iproc.crop(src, 0, 0, src:size(3) - w, src:size(2) - h)
|
||||
end
|
||||
function iproc.crop(src, w1, h1, w2, h2)
|
||||
local dest
|
||||
if src:dim() == 3 then
|
||||
dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}]:clone()
|
||||
else -- dim == 2
|
||||
dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}]:clone()
|
||||
end
|
||||
return dest
|
||||
end
|
||||
function iproc.crop_nocopy(src, w1, h1, w2, h2)
|
||||
local dest
|
||||
if src:dim() == 3 then
|
||||
dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}]
|
||||
else -- dim == 2
|
||||
dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}]
|
||||
end
|
||||
return dest
|
||||
end
|
||||
function iproc.byte2float(src)
|
||||
local conversion = false
|
||||
local dest = src
|
||||
if src:type() == "torch.ByteTensor" then
|
||||
conversion = true
|
||||
dest = src:float():div(255.0)
|
||||
end
|
||||
return dest, conversion
|
||||
end
|
||||
function iproc.float2byte(src)
|
||||
local conversion = false
|
||||
local dest = src
|
||||
if src:type() == "torch.FloatTensor" then
|
||||
conversion = true
|
||||
dest = (src + clip_eps8):mul(255.0)
|
||||
dest[torch.lt(dest, 0.0)] = 0
|
||||
dest[torch.gt(dest, 255.0)] = 255.0
|
||||
dest = dest:byte()
|
||||
end
|
||||
return dest, conversion
|
||||
end
|
||||
function iproc.scale(src, width, height, filter)
|
||||
local conversion
|
||||
src, conversion = iproc.byte2float(src)
|
||||
filter = filter or "Box"
|
||||
local im = gm.Image(src, "RGB", "DHW")
|
||||
im:size(math.ceil(width), math.ceil(height), filter)
|
||||
return im:toTensor(t, "RGB", "DHW")
|
||||
local dest = im:toTensor("float", "RGB", "DHW")
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
return dest
|
||||
end
|
||||
function iproc.scale_with_gamma22(src, width, height, filter)
|
||||
local conversion
|
||||
src, conversion = iproc.byte2float(src)
|
||||
filter = filter or "Box"
|
||||
local im = gm.Image(src, "RGB", "DHW")
|
||||
im:gammaCorrection(1.0 / 2.2):
|
||||
size(math.ceil(width), math.ceil(height), filter):
|
||||
gammaCorrection(2.2)
|
||||
local dest = im:toTensor("float", "RGB", "DHW")
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
return dest
|
||||
end
|
||||
function iproc.padding(img, w1, w2, h1, h2)
|
||||
local dst_height = img:size(2) + h1 + h2
|
||||
|
@ -22,5 +84,51 @@ function iproc.padding(img, w1, w2, h1, h2)
|
|||
flow[2]:add(-w1)
|
||||
return image.warp(img, flow, "simple", false, "clamp")
|
||||
end
|
||||
function iproc.white_noise(src, std, rgb_weights, gamma)
|
||||
gamma = gamma or 0.454545
|
||||
local conversion
|
||||
src, conversion = iproc.byte2float(src)
|
||||
std = std or 0.01
|
||||
|
||||
local noise = torch.Tensor():resizeAs(src):normal(0, std)
|
||||
if rgb_weights then
|
||||
noise[1]:mul(rgb_weights[1])
|
||||
noise[2]:mul(rgb_weights[2])
|
||||
noise[3]:mul(rgb_weights[3])
|
||||
end
|
||||
|
||||
local dest
|
||||
if gamma ~= 0 then
|
||||
dest = src:clone():pow(gamma):add(noise)
|
||||
dest[torch.lt(dest, 0.0)] = 0.0
|
||||
dest[torch.gt(dest, 1.0)] = 1.0
|
||||
dest:pow(1.0 / gamma)
|
||||
else
|
||||
dest = src + noise
|
||||
end
|
||||
if conversion then
|
||||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
return dest
|
||||
end
|
||||
|
||||
local function test_conversion()
|
||||
local a = torch.linspace(0, 255, 256):float():div(255.0)
|
||||
local b = iproc.float2byte(a)
|
||||
local c = iproc.byte2float(a)
|
||||
local d = torch.linspace(0, 255, 256)
|
||||
assert((a - c):abs():sum() == 0)
|
||||
assert((d:float() - b:float()):abs():sum() == 0)
|
||||
|
||||
a = torch.FloatTensor({256.0, 255.0, 254.999}):div(255.0)
|
||||
b = iproc.float2byte(a)
|
||||
assert(b:float():sum() == 255.0 * 3)
|
||||
|
||||
a = torch.FloatTensor({254.0, 254.499, 253.50001}):div(255.0)
|
||||
b = iproc.float2byte(a)
|
||||
print(b)
|
||||
assert(b:float():sum() == 254.0 * 3)
|
||||
end
|
||||
--test_conversion()
|
||||
|
||||
return iproc
|
||||
|
|
|
@ -21,20 +21,15 @@ local function minibatch_adam(model, criterion,
|
|||
input_size[1], input_size[2], input_size[3])
|
||||
local targets_tmp = torch.Tensor(batch_size,
|
||||
target_size[1] * target_size[2] * target_size[3])
|
||||
|
||||
for t = 1, #train_x, batch_size do
|
||||
if t + batch_size > #train_x then
|
||||
break
|
||||
end
|
||||
for t = 1, #train_x do
|
||||
xlua.progress(t, #train_x)
|
||||
for i = 1, batch_size do
|
||||
local x, y = transformer(train_x[shuffle[t + i - 1]])
|
||||
inputs_tmp[i]:copy(x)
|
||||
targets_tmp[i]:copy(y)
|
||||
local xy = transformer(train_x[shuffle[t]], false, batch_size)
|
||||
for i = 1, #xy do
|
||||
inputs_tmp[i]:copy(xy[i][1])
|
||||
targets_tmp[i]:copy(xy[i][2])
|
||||
end
|
||||
inputs:copy(inputs_tmp)
|
||||
targets:copy(targets_tmp)
|
||||
|
||||
local feval = function(x)
|
||||
if x ~= parameters then
|
||||
parameters:copy(x)
|
||||
|
@ -50,13 +45,13 @@ local function minibatch_adam(model, criterion,
|
|||
optim.adam(feval, parameters, config)
|
||||
|
||||
c = c + 1
|
||||
if c % 10 == 0 then
|
||||
if c % 20 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
xlua.progress(#train_x, #train_x)
|
||||
|
||||
return { mse = sum_loss / count_loss}
|
||||
return { loss = sum_loss / count_loss}
|
||||
end
|
||||
|
||||
return minibatch_adam
|
||||
|
|
|
@ -1,69 +1,80 @@
|
|||
require 'image'
|
||||
local gm = require 'graphicsmagick'
|
||||
local iproc = require './iproc'
|
||||
local reconstruct = require './reconstruct'
|
||||
local iproc = require 'iproc'
|
||||
local data_augmentation = require 'data_augmentation'
|
||||
|
||||
local pairwise_transform = {}
|
||||
|
||||
local function random_half(src, p, min_size)
|
||||
p = p or 0.5
|
||||
local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
|
||||
if p > torch.uniform() then
|
||||
local function random_half(src, p)
|
||||
if torch.uniform() < p then
|
||||
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
|
||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
local function color_augment(x)
|
||||
local color_scale = torch.Tensor(3):uniform(0.8, 1.2)
|
||||
x = x:float():div(255)
|
||||
for i = 1, 3 do
|
||||
x[i]:mul(color_scale[i])
|
||||
end
|
||||
x[torch.lt(x, 0.0)] = 0.0
|
||||
x[torch.gt(x, 1.0)] = 1.0
|
||||
return x:mul(255):byte()
|
||||
end
|
||||
local function flip_augment(x, y)
|
||||
local flip = torch.random(1, 4)
|
||||
if y then
|
||||
if flip == 1 then
|
||||
x = image.hflip(x)
|
||||
y = image.hflip(y)
|
||||
elseif flip == 2 then
|
||||
x = image.vflip(x)
|
||||
y = image.vflip(y)
|
||||
elseif flip == 3 then
|
||||
x = image.hflip(image.vflip(x))
|
||||
y = image.hflip(image.vflip(y))
|
||||
elseif flip == 4 then
|
||||
local function crop_if_large(src, max_size)
|
||||
local tries = 4
|
||||
if src:size(2) > max_size and src:size(3) > max_size then
|
||||
local rect
|
||||
for i = 1, tries do
|
||||
local yi = torch.random(0, src:size(2) - max_size)
|
||||
local xi = torch.random(0, src:size(3) - max_size)
|
||||
rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
|
||||
-- ignore simple background
|
||||
if rect:float():std() >= 0 then
|
||||
break
|
||||
end
|
||||
end
|
||||
return x, y
|
||||
return rect
|
||||
else
|
||||
if flip == 1 then
|
||||
x = image.hflip(x)
|
||||
elseif flip == 2 then
|
||||
x = image.vflip(x)
|
||||
elseif flip == 3 then
|
||||
x = image.hflip(image.vflip(x))
|
||||
elseif flip == 4 then
|
||||
end
|
||||
return x
|
||||
return src
|
||||
end
|
||||
end
|
||||
local INTERPOLATION_PADDING = 16
|
||||
function pairwise_transform.scale(src, scale, size, offset, options)
|
||||
options = options or {color_augment = true, random_half = true, rgb = true}
|
||||
if options.random_half then
|
||||
src = random_half(src)
|
||||
local function preprocess(src, crop_size, options)
|
||||
local dest = src
|
||||
dest = random_half(dest, options.random_half_rate)
|
||||
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.shift_1px(dest)
|
||||
|
||||
return dest
|
||||
end
|
||||
local function active_cropping(x, y, size, p, tries)
|
||||
assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
local r = torch.uniform()
|
||||
if p < r then
|
||||
local xi = torch.random(0, y:size(3) - (size + 1))
|
||||
local yi = torch.random(0, y:size(2) - (size + 1))
|
||||
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
|
||||
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||
return xc, yc
|
||||
else
|
||||
local best_se = 0.0
|
||||
local best_xc, best_yc
|
||||
local m = torch.FloatTensor(x:size(1), size, size)
|
||||
for i = 1, tries do
|
||||
local xi = torch.random(0, y:size(3) - (size + 1))
|
||||
local yi = torch.random(0, y:size(2) - (size + 1))
|
||||
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
|
||||
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||
local xcf = iproc.byte2float(xc)
|
||||
local ycf = iproc.byte2float(yc)
|
||||
local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
|
||||
if se >= best_se then
|
||||
best_xc = xcf
|
||||
best_yc = ycf
|
||||
best_se = se
|
||||
end
|
||||
end
|
||||
return best_xc, best_yc
|
||||
end
|
||||
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
|
||||
local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
|
||||
local down_scale = 1.0 / scale
|
||||
local y = image.crop(src,
|
||||
xi - INTERPOLATION_PADDING, yi - INTERPOLATION_PADDING,
|
||||
xi + size + INTERPOLATION_PADDING, yi + size + INTERPOLATION_PADDING)
|
||||
end
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
local filters = {
|
||||
"Box", -- 0.012756949974688
|
||||
"Box","Box", -- 0.012756949974688
|
||||
"Blackman", -- 0.013191924552285
|
||||
--"Cartom", -- 0.013753536746706
|
||||
--"Hanning", -- 0.013761314529647
|
||||
|
@ -71,221 +82,173 @@ function pairwise_transform.scale(src, scale, size, offset, options)
|
|||
"SincFast", -- 0.014095824314306
|
||||
"Jinc", -- 0.014244299255442
|
||||
}
|
||||
local unstable_region_offset = 8
|
||||
local downscale_filter = filters[torch.random(1, #filters)]
|
||||
|
||||
y = flip_augment(y)
|
||||
if options.color_augment then
|
||||
y = color_augment(y)
|
||||
end
|
||||
local x = iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
|
||||
x = iproc.scale(x, y:size(3), y:size(2))
|
||||
y = y:float():div(255)
|
||||
x = x:float():div(255)
|
||||
|
||||
if options.rgb then
|
||||
else
|
||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
||||
end
|
||||
|
||||
y = image.crop(y, INTERPOLATION_PADDING + offset, INTERPOLATION_PADDING + offset, y:size(3) - offset - INTERPOLATION_PADDING, y:size(2) - offset - INTERPOLATION_PADDING)
|
||||
x = image.crop(x, INTERPOLATION_PADDING, INTERPOLATION_PADDING, x:size(3) - INTERPOLATION_PADDING, x:size(2) - INTERPOLATION_PADDING)
|
||||
|
||||
return x, y
|
||||
end
|
||||
function pairwise_transform.jpeg_(src, quality, size, offset, options)
|
||||
options = options or {color_augment = true, random_half = true, rgb = true}
|
||||
if options.random_half then
|
||||
src = random_half(src)
|
||||
end
|
||||
local yi = torch.random(0, src:size(2) - size - 1)
|
||||
local xi = torch.random(0, src:size(3) - size - 1)
|
||||
local y = src
|
||||
local x
|
||||
|
||||
if options.color_augment then
|
||||
y = color_augment(y)
|
||||
end
|
||||
x = y
|
||||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg")
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
local blob, len = x:toBlob(quality[i])
|
||||
x:fromBlob(blob, len)
|
||||
x = x:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
|
||||
y = image.crop(y, xi, yi, xi + size, yi + size)
|
||||
x = image.crop(x, xi, yi, xi + size, yi + size)
|
||||
y = y:float():div(255)
|
||||
x = x:float():div(255)
|
||||
x, y = flip_augment(x, y)
|
||||
|
||||
if options.rgb then
|
||||
else
|
||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
||||
end
|
||||
|
||||
return x, image.crop(y, offset, offset, size - offset, size - offset)
|
||||
end
|
||||
function pairwise_transform.jpeg(src, level, size, offset, options)
|
||||
if level == 1 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
|
||||
size, offset,
|
||||
options)
|
||||
elseif level == 2 then
|
||||
local r = torch.uniform()
|
||||
if r > 0.6 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
|
||||
size, offset,
|
||||
options)
|
||||
elseif r > 0.3 then
|
||||
local quality1 = torch.random(37, 70)
|
||||
local quality2 = quality1 - torch.random(5, 10)
|
||||
return pairwise_transform.jpeg_(src, {quality1, quality2},
|
||||
size, offset,
|
||||
options)
|
||||
else
|
||||
local quality1 = torch.random(52, 70)
|
||||
return pairwise_transform.jpeg_(src,
|
||||
{quality1,
|
||||
quality1 - torch.random(5, 15),
|
||||
quality1 - torch.random(15, 25)},
|
||||
size, offset,
|
||||
options)
|
||||
end
|
||||
else
|
||||
error("unknown noise level: " .. level)
|
||||
end
|
||||
end
|
||||
function pairwise_transform.jpeg_scale_(src, scale, quality, size, offset, options)
|
||||
if options.random_half then
|
||||
src = random_half(src)
|
||||
end
|
||||
local y = preprocess(src, size, options)
|
||||
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
|
||||
local down_scale = 1.0 / scale
|
||||
local filters = {
|
||||
"Box", -- 0.012756949974688
|
||||
"Blackman", -- 0.013191924552285
|
||||
--"Cartom", -- 0.013753536746706
|
||||
--"Hanning", -- 0.013761314529647
|
||||
--"Hermite", -- 0.013850225205266
|
||||
"SincFast", -- 0.014095824314306
|
||||
"Jinc", -- 0.014244299255442
|
||||
}
|
||||
local downscale_filter = filters[torch.random(1, #filters)]
|
||||
local yi = torch.random(INTERPOLATION_PADDING, src:size(2) - size - INTERPOLATION_PADDING)
|
||||
local xi = torch.random(INTERPOLATION_PADDING, src:size(3) - size - INTERPOLATION_PADDING)
|
||||
local y = src
|
||||
local x
|
||||
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
|
||||
y:size(2) * down_scale, downscale_filter),
|
||||
y:size(3), y:size(2))
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
||||
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
||||
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
||||
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
|
||||
if options.color_augment then
|
||||
y = color_augment(y)
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
local xc, yc = active_cropping(x, y,
|
||||
size,
|
||||
options.active_cropping_rate,
|
||||
options.active_cropping_tries)
|
||||
xc = iproc.byte2float(xc)
|
||||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
end
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
x = y
|
||||
x = iproc.scale(x, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter)
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
||||
local unstable_region_offset = 8
|
||||
local y = preprocess(src, size, options)
|
||||
local x = y
|
||||
|
||||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg")
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
x:format("jpeg"):depth(8)
|
||||
if options.jpeg_sampling_factors == 444 then
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
else -- 420
|
||||
x:samplingFactors({2.0, 1.0, 1.0})
|
||||
end
|
||||
local blob, len = x:toBlob(quality[i])
|
||||
x:fromBlob(blob, len)
|
||||
x = x:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
x = iproc.scale(x, y:size(3), y:size(2))
|
||||
y = image.crop(y,
|
||||
xi, yi,
|
||||
xi + size, yi + size)
|
||||
x = image.crop(x,
|
||||
xi, yi,
|
||||
xi + size, yi + size)
|
||||
x = x:float():div(255)
|
||||
y = y:float():div(255)
|
||||
x, y = flip_augment(x, y)
|
||||
|
||||
if options.rgb then
|
||||
else
|
||||
y = image.rgb2yuv(y)[1]:reshape(1, y:size(2), y:size(3))
|
||||
x = image.rgb2yuv(x)[1]:reshape(1, x:size(2), x:size(3))
|
||||
end
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
||||
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
||||
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
||||
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
|
||||
return x, image.crop(y, offset, offset, size - offset, size - offset)
|
||||
end
|
||||
function pairwise_transform.jpeg_scale(src, scale, level, size, offset, options)
|
||||
options = options or {color_augment = true, random_half = true}
|
||||
if level == 1 then
|
||||
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(65, 85)},
|
||||
size, offset, options)
|
||||
elseif level == 2 then
|
||||
local r = torch.uniform()
|
||||
if r > 0.6 then
|
||||
return pairwise_transform.jpeg_scale_(src, scale, {torch.random(27, 70)},
|
||||
size, offset, options)
|
||||
elseif r > 0.3 then
|
||||
local quality1 = torch.random(37, 70)
|
||||
local quality2 = quality1 - torch.random(5, 10)
|
||||
return pairwise_transform.jpeg_scale_(src, scale, {quality1, quality2},
|
||||
size, offset, options)
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
local xc, yc = active_cropping(x, y, size,
|
||||
options.active_cropping_rate,
|
||||
options.active_cropping_tries)
|
||||
xc = iproc.byte2float(xc)
|
||||
yc = iproc.byte2float(yc)
|
||||
if options.rgb then
|
||||
else
|
||||
local quality1 = torch.random(52, 70)
|
||||
return pairwise_transform.jpeg_scale_(src, scale,
|
||||
{quality1,
|
||||
quality1 - torch.random(5, 15),
|
||||
quality1 - torch.random(15, 25)},
|
||||
size, offset, options)
|
||||
yc = image.rgb2yuv(yc)[1]:reshape(1, yc:size(2), yc:size(3))
|
||||
xc = image.rgb2yuv(xc)[1]:reshape(1, xc:size(2), xc:size(3))
|
||||
end
|
||||
if torch.uniform() < options.nr_rate then
|
||||
-- reducing noise
|
||||
table.insert(batch, {xc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
else
|
||||
-- ratain useful details
|
||||
table.insert(batch, {yc, iproc.crop(yc, offset, offset, size - offset, size - offset)})
|
||||
end
|
||||
end
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
|
||||
if style == "art" then
|
||||
if level == 1 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
|
||||
size, offset, n, options)
|
||||
elseif level == 2 then
|
||||
local r = torch.uniform()
|
||||
if r > 0.6 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
|
||||
size, offset, n, options)
|
||||
elseif r > 0.3 then
|
||||
local quality1 = torch.random(37, 70)
|
||||
local quality2 = quality1 - torch.random(5, 10)
|
||||
return pairwise_transform.jpeg_(src, {quality1, quality2},
|
||||
size, offset, n, options)
|
||||
else
|
||||
local quality1 = torch.random(52, 70)
|
||||
local quality2 = quality1 - torch.random(5, 15)
|
||||
local quality3 = quality1 - torch.random(15, 25)
|
||||
|
||||
return pairwise_transform.jpeg_(src,
|
||||
{quality1, quality2, quality3},
|
||||
size, offset, n, options)
|
||||
end
|
||||
else
|
||||
error("unknown noise level: " .. level)
|
||||
end
|
||||
elseif style == "photo" then
|
||||
if level == 1 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(30, 75)},
|
||||
size, offset, n,
|
||||
options)
|
||||
elseif level == 2 then
|
||||
if torch.uniform() > 0.6 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(30, 60)},
|
||||
size, offset, n, options)
|
||||
else
|
||||
local quality1 = torch.random(40, 60)
|
||||
local quality2 = quality1 - torch.random(5, 10)
|
||||
return pairwise_transform.jpeg_(src, {quality1, quality2},
|
||||
size, offset, n, options)
|
||||
end
|
||||
else
|
||||
error("unknown noise level: " .. level)
|
||||
end
|
||||
else
|
||||
error("unknown noise level: " .. level)
|
||||
error("unknown style: " .. style)
|
||||
end
|
||||
end
|
||||
|
||||
local function test_jpeg()
|
||||
local loader = require './image_loader'
|
||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||
local y, x = pairwise_transform.jpeg_(src, {}, 128, 0, false)
|
||||
image.display({image = y, legend = "y:0"})
|
||||
image.display({image = x, legend = "x:0"})
|
||||
for i = 2, 9 do
|
||||
local y, x = pairwise_transform.jpeg_(pairwise_transform.random_half(src),
|
||||
{i * 10}, 128, 0, {color_augment = false, random_half = true})
|
||||
image.display({image = y, legend = "y:" .. (i * 10), max=1,min=0})
|
||||
image.display({image = x, legend = "x:" .. (i * 10),max=1,min=0})
|
||||
--print(x:mean(), y:mean())
|
||||
function pairwise_transform.test_jpeg(src)
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
local options = {random_color_noise_rate = 0.5,
|
||||
random_half_rate = 0.5,
|
||||
random_overlay_rate = 0.5,
|
||||
nr_rate = 1.0,
|
||||
active_cropping_rate = 0.5,
|
||||
active_cropping_tries = 10,
|
||||
max_size = 256,
|
||||
rgb = true
|
||||
}
|
||||
local image = require 'image'
|
||||
local src = image.lena()
|
||||
for i = 1, 9 do
|
||||
local xy = pairwise_transform.jpeg(src,
|
||||
"art",
|
||||
torch.random(1, 2),
|
||||
128, 7, 1, options)
|
||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
|
||||
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min=0, max=1})
|
||||
end
|
||||
end
|
||||
function pairwise_transform.test_scale(src)
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
local options = {random_color_noise_rate = 0.5,
|
||||
random_half_rate = 0.5,
|
||||
random_overlay_rate = 0.5,
|
||||
active_cropping_rate = 0.5,
|
||||
active_cropping_tries = 10,
|
||||
max_size = 256,
|
||||
rgb = true
|
||||
}
|
||||
local image = require 'image'
|
||||
local src = image.lena()
|
||||
|
||||
local function test_scale()
|
||||
local loader = require './image_loader'
|
||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||
for i = 1, 9 do
|
||||
local y, x = pairwise_transform.scale(src, 2.0, 128, 7, {color_augment = true, random_half = true, rgb = true})
|
||||
image.display({image = y, legend = "y:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = x, legend = "x:" .. (i * 10), min = 0, max = 1})
|
||||
print(y:size(), x:size())
|
||||
--print(x:mean(), y:mean())
|
||||
for i = 1, 10 do
|
||||
local xy = pairwise_transform.scale(src, 2.0, 128, 7, 1, options)
|
||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = xy[1][2], legend = "x:" .. (i * 10), min = 0, max = 1})
|
||||
end
|
||||
end
|
||||
local function test_jpeg_scale()
|
||||
local loader = require './image_loader'
|
||||
local src = loader.load_byte("../images/miku_CC_BY-NC.jpg")
|
||||
for i = 1, 9 do
|
||||
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 1, 128, 7, {color_augment = true, random_half = true})
|
||||
image.display({image = y, legend = "y1:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = x, legend = "x1:" .. (i * 10), min = 0, max = 1})
|
||||
print(y:size(), x:size())
|
||||
--print(x:mean(), y:mean())
|
||||
end
|
||||
for i = 1, 9 do
|
||||
local y, x = pairwise_transform.jpeg_scale(src, 2.0, 2, 128, 7, {color_augment = true, random_half = true})
|
||||
image.display({image = y, legend = "y2:" .. (i * 10), min = 0, max = 1})
|
||||
image.display({image = x, legend = "x2:" .. (i * 10), min = 0, max = 1})
|
||||
print(y:size(), x:size())
|
||||
--print(x:mean(), y:mean())
|
||||
end
|
||||
end
|
||||
--test_scale()
|
||||
--test_jpeg()
|
||||
--test_jpeg_scale()
|
||||
|
||||
return pairwise_transform
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
require 'torch'
|
||||
require 'cutorch'
|
||||
require 'nn'
|
||||
require 'cunn'
|
|
@ -1,5 +1,5 @@
|
|||
require 'image'
|
||||
local iproc = require './iproc'
|
||||
local iproc = require 'iproc'
|
||||
|
||||
local function reconstruct_y(model, x, offset, block_size)
|
||||
if x:dim() == 2 then
|
||||
|
@ -48,7 +48,8 @@ local function reconstruct_rgb(model, x, offset, block_size)
|
|||
end
|
||||
return new_x
|
||||
end
|
||||
function model_is_rgb(model)
|
||||
local reconstruct = {}
|
||||
function reconstruct.is_rgb(model)
|
||||
if model:get(model:size() - 1).weight:size(1) == 3 then
|
||||
-- 3ch RGB
|
||||
return true
|
||||
|
@ -57,8 +58,23 @@ function model_is_rgb(model)
|
|||
return false
|
||||
end
|
||||
end
|
||||
|
||||
local reconstruct = {}
|
||||
function reconstruct.offset_size(model)
|
||||
local conv = model:findModules("nn.SpatialConvolutionMM")
|
||||
if #conv > 0 then
|
||||
local offset = 0
|
||||
for i = 1, #conv do
|
||||
offset = offset + (conv[i].kW - 1) / 2
|
||||
end
|
||||
return math.floor(offset)
|
||||
else
|
||||
conv = model:findModules("cudnn.SpatialConvolution")
|
||||
local offset = 0
|
||||
for i = 1, #conv do
|
||||
offset = offset + (conv[i].kW - 1) / 2
|
||||
end
|
||||
return math.floor(offset)
|
||||
end
|
||||
end
|
||||
function reconstruct.image_y(model, x, offset, block_size)
|
||||
block_size = block_size or 128
|
||||
local output_size = block_size - offset * 2
|
||||
|
@ -78,7 +94,7 @@ function reconstruct.image_y(model, x, offset, block_size)
|
|||
y[torch.lt(y, 0)] = 0
|
||||
y[torch.gt(y, 1)] = 1
|
||||
yuv[1]:copy(y)
|
||||
local output = image.yuv2rgb(image.crop(yuv,
|
||||
local output = image.yuv2rgb(iproc.crop(yuv,
|
||||
pad_w1, pad_h1,
|
||||
yuv:size(3) - pad_w2, yuv:size(2) - pad_h2))
|
||||
output[torch.lt(output, 0)] = 0
|
||||
|
@ -110,7 +126,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size)
|
|||
y[torch.lt(y, 0)] = 0
|
||||
y[torch.gt(y, 1)] = 1
|
||||
yuv_jinc[1]:copy(y)
|
||||
local output = image.yuv2rgb(image.crop(yuv_jinc,
|
||||
local output = image.yuv2rgb(iproc.crop(yuv_jinc,
|
||||
pad_w1, pad_h1,
|
||||
yuv_jinc:size(3) - pad_w2, yuv_jinc:size(2) - pad_h2))
|
||||
output[torch.lt(output, 0)] = 0
|
||||
|
@ -135,7 +151,7 @@ function reconstruct.image_rgb(model, x, offset, block_size)
|
|||
local pad_w2 = (w - offset) - x:size(3)
|
||||
local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
|
||||
local y = reconstruct_rgb(model, input, offset, block_size)
|
||||
local output = image.crop(y,
|
||||
local output = iproc.crop(y,
|
||||
pad_w1, pad_h1,
|
||||
y:size(3) - pad_w2, y:size(2) - pad_h2)
|
||||
collectgarbage()
|
||||
|
@ -162,7 +178,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
|||
local pad_w2 = (w - offset) - x:size(3)
|
||||
local input = iproc.padding(x, pad_w1, pad_w2, pad_h1, pad_h2)
|
||||
local y = reconstruct_rgb(model, input, offset, block_size)
|
||||
local output = image.crop(y,
|
||||
local output = iproc.crop(y,
|
||||
pad_w1, pad_h1,
|
||||
y:size(3) - pad_w2, y:size(2) - pad_h2)
|
||||
output[torch.lt(output, 0)] = 0
|
||||
|
@ -172,18 +188,81 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
|||
return output
|
||||
end
|
||||
|
||||
function reconstruct.image(model, x, offset, block_size)
|
||||
if model_is_rgb(model) then
|
||||
return reconstruct.image_rgb(model, x, offset, block_size)
|
||||
function reconstruct.image(model, x, block_size)
|
||||
if reconstruct.is_rgb(model) then
|
||||
return reconstruct.image_rgb(model, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
else
|
||||
return reconstruct.image_y(model, x, offset, block_size)
|
||||
return reconstruct.image_y(model, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
end
|
||||
end
|
||||
function reconstruct.scale(model, scale, x, offset, block_size)
|
||||
if model_is_rgb(model) then
|
||||
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
||||
function reconstruct.scale(model, scale, x, block_size)
|
||||
if reconstruct.is_rgb(model) then
|
||||
return reconstruct.scale_rgb(model, scale, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
else
|
||||
return reconstruct.scale_y(model, scale, x, offset, block_size)
|
||||
return reconstruct.scale_y(model, scale, x,
|
||||
reconstruct.offset_size(model), block_size)
|
||||
end
|
||||
end
|
||||
local function tta(f, model, x, block_size)
|
||||
local average = nil
|
||||
local offset = reconstruct.offset_size(model)
|
||||
for i = 1, 4 do
|
||||
local flip_f, iflip_f
|
||||
if i == 1 then
|
||||
flip_f = function (a) return a end
|
||||
iflip_f = function (a) return a end
|
||||
elseif i == 2 then
|
||||
flip_f = image.vflip
|
||||
iflip_f = image.vflip
|
||||
elseif i == 3 then
|
||||
flip_f = image.hflip
|
||||
iflip_f = image.hflip
|
||||
elseif i == 4 then
|
||||
flip_f = function (a) return image.hflip(image.vflip(a)) end
|
||||
iflip_f = function (a) return image.vflip(image.hflip(a)) end
|
||||
end
|
||||
for j = 1, 2 do
|
||||
local tr_f, itr_f
|
||||
if j == 1 then
|
||||
tr_f = function (a) return a end
|
||||
itr_f = function (a) return a end
|
||||
elseif j == 2 then
|
||||
tr_f = function(a) return a:transpose(2, 3):contiguous() end
|
||||
itr_f = function(a) return a:transpose(2, 3):contiguous() end
|
||||
end
|
||||
local out = itr_f(iflip_f(f(model, flip_f(tr_f(x)),
|
||||
offset, block_size)))
|
||||
if not average then
|
||||
average = out
|
||||
else
|
||||
average:add(out)
|
||||
end
|
||||
end
|
||||
end
|
||||
return average:div(8.0)
|
||||
end
|
||||
function reconstruct.image_tta(model, x, block_size)
|
||||
if reconstruct.is_rgb(model) then
|
||||
return tta(reconstruct.image_rgb, model, x, block_size)
|
||||
else
|
||||
return tta(reconstruct.image_y, model, x, block_size)
|
||||
end
|
||||
end
|
||||
function reconstruct.scale_tta(model, scale, x, block_size)
|
||||
if reconstruct.is_rgb(model) then
|
||||
local f = function (model, x, offset, block_size)
|
||||
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
||||
end
|
||||
return tta(f, model, x, block_size)
|
||||
|
||||
else
|
||||
local f = function (model, x, offset, block_size)
|
||||
return reconstruct.scale_y(model, scale, x, offset, block_size)
|
||||
end
|
||||
return tta(f, model, x, block_size)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
require 'xlua'
|
||||
require 'pl'
|
||||
require 'trepl'
|
||||
|
||||
-- global settings
|
||||
|
||||
|
@ -14,22 +15,34 @@ local settings = {}
|
|||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("waifu2x")
|
||||
cmd:text("waifu2x-training")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-seed", 11, 'fixed input seed')
|
||||
cmd:option("-data_dir", "./data", 'data directory')
|
||||
cmd:option("-test", "images/miku_small.png", 'test image file')
|
||||
cmd:option("-gpu", -1, 'GPU Device ID')
|
||||
cmd:option("-seed", 11, 'RNG seed')
|
||||
cmd:option("-data_dir", "./data", 'path to data directory')
|
||||
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
||||
cmd:option("-test", "images/miku_small.png", 'path to test image')
|
||||
cmd:option("-model_dir", "./models", 'model directory')
|
||||
cmd:option("-method", "scale", '(noise|scale|noise_scale)')
|
||||
cmd:option("-method", "scale", 'method to training (noise|scale)')
|
||||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-style", "art", '(art|photo)')
|
||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||
cmd:option("-scale", 2.0, 'scale')
|
||||
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
|
||||
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
|
||||
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
|
||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image')
|
||||
cmd:option("-crop_size", 128, 'crop size')
|
||||
cmd:option("-batch_size", 2, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'epoch')
|
||||
cmd:option("-core", 2, 'cpu core')
|
||||
cmd:option("-crop_size", 46, 'crop size')
|
||||
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
|
||||
cmd:option("-batch_size", 8, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'number of total epochs to run')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
|
||||
cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
|
||||
cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
|
||||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
for k, v in pairs(opt) do
|
||||
|
@ -53,26 +66,16 @@ end
|
|||
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
|
||||
error("scale must be mod-2")
|
||||
end
|
||||
if settings.random_half == 1 then
|
||||
settings.random_half = true
|
||||
else
|
||||
settings.random_half = false
|
||||
if not (settings.style == "art" or
|
||||
settings.style == "photo") then
|
||||
error(string.format("unknown style: %s", settings.style))
|
||||
end
|
||||
|
||||
if settings.thread > 0 then
|
||||
torch.setnumthreads(tonumber(settings.thread))
|
||||
end
|
||||
torch.setnumthreads(settings.core)
|
||||
|
||||
settings.images = string.format("%s/images.t7", settings.data_dir)
|
||||
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
||||
|
||||
settings.validation_ratio = 0.1
|
||||
settings.validation_crops = 40
|
||||
|
||||
local srcnn = require './srcnn'
|
||||
if (settings.method == "scale" or settings.method == "noise_scale") and settings.scale == 4 then
|
||||
settings.create_model = srcnn.waifu4x
|
||||
settings.block_offset = 13
|
||||
else
|
||||
settings.create_model = srcnn.waifu2x
|
||||
settings.block_offset = 7
|
||||
end
|
||||
|
||||
return settings
|
||||
|
|
|
@ -1,74 +1,68 @@
|
|||
require './LeakyReLU'
|
||||
require 'w2nn'
|
||||
|
||||
function nn.SpatialConvolutionMM:reset(stdv)
|
||||
stdv = math.sqrt(2 / ( self.kW * self.kH * self.nOutputPlane))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:fill(0)
|
||||
end
|
||||
-- ref: http://arxiv.org/abs/1502.01852
|
||||
-- ref: http://arxiv.org/abs/1501.00092
|
||||
local srcnn = {}
|
||||
function srcnn.waifu2x(color)
|
||||
function srcnn.channels(model)
|
||||
return model:get(model:size() - 1).weight:size(1)
|
||||
end
|
||||
function srcnn.waifu2x_cunn(ch)
|
||||
local model = nn.Sequential()
|
||||
local ch = nil
|
||||
if color == "rgb" then
|
||||
ch = 3
|
||||
elseif color == "y" then
|
||||
ch = 1
|
||||
else
|
||||
if color then
|
||||
error("unknown color: " .. color)
|
||||
else
|
||||
error("unknown color: nil")
|
||||
end
|
||||
end
|
||||
|
||||
model:add(nn.SpatialConvolutionMM(ch, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, 1, 92, 92):uniform():cuda()):size())
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
return model, 7
|
||||
return model
|
||||
end
|
||||
|
||||
-- current 4x is worse than 2x * 2
|
||||
function srcnn.waifu4x(color)
|
||||
function srcnn.waifu2x_cudnn(ch)
|
||||
local model = nn.Sequential()
|
||||
|
||||
local ch = nil
|
||||
model:add(cudnn.SpatialConvolution(ch, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(w2nn.LeakyReLU(0.1))
|
||||
model:add(cudnn.SpatialConvolution(128, ch, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
--model:cuda()
|
||||
--print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
|
||||
|
||||
return model
|
||||
end
|
||||
function srcnn.create(model_name, backend, color)
|
||||
local ch = 3
|
||||
if color == "rgb" then
|
||||
ch = 3
|
||||
elseif color == "y" then
|
||||
ch = 1
|
||||
else
|
||||
error("unknown color: " .. color)
|
||||
error("unsupported color: " + color)
|
||||
end
|
||||
if backend == "cunn" then
|
||||
return srcnn.waifu2x_cunn(ch)
|
||||
elseif backend == "cudnn" then
|
||||
return srcnn.waifu2x_cudnn(ch)
|
||||
else
|
||||
error("unsupported backend: " + backend)
|
||||
end
|
||||
|
||||
model:add(nn.SpatialConvolutionMM(ch, 32, 9, 9, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 32, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(32, 64, 5, 5, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 64, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(64, 128, 5, 5, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, 128, 3, 3, 1, 1, 0, 0))
|
||||
model:add(nn.LeakyReLU(0.1))
|
||||
model:add(nn.SpatialConvolutionMM(128, ch, 5, 5, 1, 1, 0, 0))
|
||||
model:add(nn.View(-1):setNumInputDims(3))
|
||||
|
||||
return model, 13
|
||||
end
|
||||
return srcnn
|
||||
|
|
26
lib/w2nn.lua
Normal file
|
@ -0,0 +1,26 @@
|
|||
local function load_nn()
|
||||
require 'torch'
|
||||
require 'nn'
|
||||
end
|
||||
local function load_cunn()
|
||||
require 'cutorch'
|
||||
require 'cunn'
|
||||
end
|
||||
local function load_cudnn()
|
||||
require 'cudnn'
|
||||
cudnn.benchmark = true
|
||||
end
|
||||
if w2nn then
|
||||
return w2nn
|
||||
else
|
||||
pcall(load_cunn)
|
||||
pcall(load_cudnn)
|
||||
w2nn = {}
|
||||
require 'LeakyReLU'
|
||||
require 'LeakyReLU_deprecated'
|
||||
require 'DepthExpand2x'
|
||||
require 'WeightedMSECriterion'
|
||||
require 'ClippedWeightedHuberCriterion'
|
||||
require 'cleanup_model'
|
||||
return w2nn
|
||||
end
|
169
tools/benchmark.lua
Normal file
|
@ -0,0 +1,169 @@
|
|||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
require 'xlua'
|
||||
require 'w2nn'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
local gm = require 'graphicsmagick'
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("waifu2x-benchmark")
|
||||
cmd:text("Options:")
|
||||
|
||||
cmd:option("-dir", "./data/test", 'test image directory')
|
||||
cmd:option("-model1_dir", "./models/anime_style_art_rgb", 'model1 directory')
|
||||
cmd:option("-model2_dir", "", 'model2 directory (optional)')
|
||||
cmd:option("-method", "scale", '(scale|noise)')
|
||||
cmd:option("-filter", "Box", "downscaling filter (Box|Jinc)")
|
||||
cmd:option("-color", "rgb", '(rgb|y)')
|
||||
cmd:option("-noise_level", 1, 'model noise level')
|
||||
cmd:option("-jpeg_quality", 75, 'jpeg quality')
|
||||
cmd:option("-jpeg_times", 1, 'jpeg compression times')
|
||||
cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
if cudnn then
|
||||
cudnn.fastest = true
|
||||
cudnn.benchmark = false
|
||||
end
|
||||
|
||||
local function MSE(x1, x2)
|
||||
return (x1 - x2):pow(2):mean()
|
||||
end
|
||||
local function YMSE(x1, x2)
|
||||
local x1_2 = image.rgb2y(x1)
|
||||
local x2_2 = image.rgb2y(x2)
|
||||
return (x1_2 - x2_2):pow(2):mean()
|
||||
end
|
||||
local function PSNR(x1, x2)
|
||||
local mse = MSE(x1, x2)
|
||||
return 10 * math.log10(1.0 / mse)
|
||||
end
|
||||
local function YPSNR(x1, x2)
|
||||
local mse = YMSE(x1, x2)
|
||||
return 10 * math.log10(1.0 / mse)
|
||||
end
|
||||
|
||||
local function transform_jpeg(x, opt)
|
||||
for i = 1, opt.jpeg_times do
|
||||
jpeg = gm.Image(x, "RGB", "DHW")
|
||||
jpeg:format("jpeg")
|
||||
jpeg:samplingFactors({1.0, 1.0, 1.0})
|
||||
blob, len = jpeg:toBlob(opt.jpeg_quality - (i - 1) * opt.jpeg_quality_down)
|
||||
jpeg:fromBlob(blob, len)
|
||||
x = jpeg:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
return x
|
||||
end
|
||||
local function transform_scale(x, opt)
|
||||
return iproc.scale(x,
|
||||
x:size(3) * 0.5,
|
||||
x:size(2) * 0.5,
|
||||
opt.filter)
|
||||
end
|
||||
|
||||
local function benchmark(opt, x, input_func, model1, model2)
|
||||
local model1_mse = 0
|
||||
local model2_mse = 0
|
||||
local model1_psnr = 0
|
||||
local model2_psnr = 0
|
||||
|
||||
for i = 1, #x do
|
||||
local ground_truth = x[i]
|
||||
local input, model1_output, model2_output
|
||||
|
||||
input = input_func(ground_truth, opt)
|
||||
input = input:float():div(255)
|
||||
ground_truth = ground_truth:float():div(255)
|
||||
|
||||
t = sys.clock()
|
||||
if input:size(3) == ground_truth:size(3) then
|
||||
model1_output = reconstruct.image(model1, input)
|
||||
if model2 then
|
||||
model2_output = reconstruct.image(model2, input)
|
||||
end
|
||||
else
|
||||
model1_output = reconstruct.scale(model1, 2.0, input)
|
||||
if model2 then
|
||||
model2_output = reconstruct.scale(model2, 2.0, input)
|
||||
end
|
||||
end
|
||||
if opt.color == "y" then
|
||||
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
|
||||
model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
|
||||
if model2 then
|
||||
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
|
||||
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
|
||||
end
|
||||
elseif opt.color == "rgb" then
|
||||
model1_mse = model1_mse + MSE(ground_truth, model1_output)
|
||||
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
|
||||
if model2 then
|
||||
model2_mse = model2_mse + MSE(ground_truth, model2_output)
|
||||
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
|
||||
end
|
||||
else
|
||||
error("Unknown color: " .. opt.color)
|
||||
end
|
||||
if model2 then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
i, #x,
|
||||
model1_mse / i, model2_mse / i,
|
||||
model1_psnr / i, model2_psnr / i
|
||||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
|
||||
i, #x,
|
||||
model1_mse / i, model1_psnr / i
|
||||
))
|
||||
end
|
||||
io.stdout:flush()
|
||||
end
|
||||
io.stdout:write("\n")
|
||||
end
|
||||
local function load_data(test_dir)
|
||||
local test_x = {}
|
||||
local files = dir.getfiles(test_dir, "*.*")
|
||||
for i = 1, #files do
|
||||
table.insert(test_x, iproc.crop_mod4(image_loader.load_byte(files[i])))
|
||||
xlua.progress(i, #files)
|
||||
end
|
||||
return test_x
|
||||
end
|
||||
function load_model(filename)
|
||||
return torch.load(filename, "ascii")
|
||||
end
|
||||
print(opt)
|
||||
if opt.method == "scale" then
|
||||
local f1 = path.join(opt.model1_dir, "scale2.0x_model.t7")
|
||||
local f2 = path.join(opt.model2_dir, "scale2.0x_model.t7")
|
||||
local s1, model1 = pcall(load_model, f1)
|
||||
local s2, model2 = pcall(load_model, f2)
|
||||
if not s1 then
|
||||
error("Load error: " .. f1)
|
||||
end
|
||||
if not s2 then
|
||||
model2 = nil
|
||||
end
|
||||
local test_x = load_data(opt.dir)
|
||||
benchmark(opt, test_x, transform_scale, model1, model2)
|
||||
elseif opt.method == "noise" then
|
||||
local f1 = path.join(opt.model1_dir, string.format("noise%d_model.t7", opt.noise_level))
|
||||
local f2 = path.join(opt.model2_dir, string.format("noise%d_model.t7", opt.noise_level))
|
||||
local s1, model1 = pcall(load_model, f1)
|
||||
local s2, model2 = pcall(load_model, f2)
|
||||
if not s1 then
|
||||
error("Load error: " .. f1)
|
||||
end
|
||||
if not s2 then
|
||||
model2 = nil
|
||||
end
|
||||
local test_x = load_data(opt.dir)
|
||||
benchmark(opt, test_x, transform_jpeg, model1, model2)
|
||||
end
|
25
tools/cleanup_model.lua
Normal file
|
@ -0,0 +1,25 @@
|
|||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
|
||||
require 'w2nn'
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("cleanup model")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-model", "./model.t7", 'path of model file')
|
||||
cmd:option("-iformat", "binary", 'input format')
|
||||
cmd:option("-oformat", "binary", 'output format')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
local model = torch.load(opt.model, opt.iformat)
|
||||
if model then
|
||||
w2nn.cleanup_model(model)
|
||||
model:cuda()
|
||||
model:evaluate()
|
||||
torch.save(opt.model, model, opt.oformat)
|
||||
else
|
||||
error("model not found")
|
||||
end
|
43
tools/cudnn2cunn.lua
Normal file
|
@ -0,0 +1,43 @@
|
|||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
require 'os'
|
||||
require 'w2nn'
|
||||
local srcnn = require 'srcnn'
|
||||
|
||||
local function cudnn2cunn(cudnn_model)
|
||||
local cunn_model = srcnn.waifu2x_cunn(srcnn.channels(cudnn_model))
|
||||
local weight_from = cudnn_model:findModules("cudnn.SpatialConvolution")
|
||||
local weight_to = cunn_model:findModules("nn.SpatialConvolutionMM")
|
||||
|
||||
assert(#weight_from == #weight_to)
|
||||
|
||||
for i = 1, #weight_from do
|
||||
local from = weight_from[i]
|
||||
local to = weight_to[i]
|
||||
|
||||
to.weight:copy(from.weight)
|
||||
to.bias:copy(from.bias)
|
||||
end
|
||||
cunn_model:cuda()
|
||||
cunn_model:evaluate()
|
||||
return cunn_model
|
||||
end
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("waifu2x cudnn model to cunn model converter")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-i", "", 'Specify the input cunn model')
|
||||
cmd:option("-o", "", 'Specify the output cudnn model')
|
||||
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
|
||||
cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if not path.isfile(opt.i) then
|
||||
cmd:help()
|
||||
os.exit(-1)
|
||||
end
|
||||
local cudnn_model = torch.load(opt.i, opt.iformat)
|
||||
local cunn_model = cudnn2cunn(cudnn_model)
|
||||
torch.save(opt.o, cunn_model, opt.oformat)
|
43
tools/cunn2cudnn.lua
Normal file
|
@ -0,0 +1,43 @@
|
|||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
require 'os'
|
||||
require 'w2nn'
|
||||
local srcnn = require 'srcnn'
|
||||
|
||||
local function cunn2cudnn(cunn_model)
|
||||
local cudnn_model = srcnn.waifu2x_cudnn(srcnn.channels(cunn_model))
|
||||
local weight_from = cunn_model:findModules("nn.SpatialConvolutionMM")
|
||||
local weight_to = cudnn_model:findModules("cudnn.SpatialConvolution")
|
||||
|
||||
assert(#weight_from == #weight_to)
|
||||
|
||||
for i = 1, #weight_from do
|
||||
local from = weight_from[i]
|
||||
local to = weight_to[i]
|
||||
|
||||
to.weight:copy(from.weight)
|
||||
to.bias:copy(from.bias)
|
||||
end
|
||||
cudnn_model:cuda()
|
||||
cudnn_model:evaluate()
|
||||
return cudnn_model
|
||||
end
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("waifu2x cunn model to cudnn model converter")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-i", "", 'Specify the input cudnn model')
|
||||
cmd:option("-o", "", 'Specify the output cunn model')
|
||||
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
|
||||
cmd:option("-oformat", "ascii", 'Specify the output format (ascii|binary)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if not path.isfile(opt.i) then
|
||||
cmd:help()
|
||||
os.exit(-1)
|
||||
end
|
||||
local cunn_model = torch.load(opt.i, opt.iformat)
|
||||
local cudnn_model = cunn2cudnn(cunn_model)
|
||||
torch.save(opt.o, cudnn_model, opt.oformat)
|
54
tools/export_model.lua
Normal file
|
@ -0,0 +1,54 @@
|
|||
-- adapted from https://github.com/marcan/cl-waifu2x
|
||||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
require 'w2nn'
|
||||
local cjson = require "cjson"
|
||||
|
||||
function export(model, output)
|
||||
local jmodules = {}
|
||||
local modules = model:findModules("nn.SpatialConvolutionMM")
|
||||
if #modules == 0 then
|
||||
-- cudnn model
|
||||
modules = model:findModules("cudnn.SpatialConvolution")
|
||||
end
|
||||
for i = 1, #modules, 1 do
|
||||
local module = modules[i]
|
||||
local jmod = {
|
||||
kW = module.kW,
|
||||
kH = module.kH,
|
||||
nInputPlane = module.nInputPlane,
|
||||
nOutputPlane = module.nOutputPlane,
|
||||
bias = torch.totable(module.bias:float()),
|
||||
weight = torch.totable(module.weight:float():reshape(module.nOutputPlane, module.nInputPlane, module.kW, module.kH))
|
||||
}
|
||||
table.insert(jmodules, jmod)
|
||||
end
|
||||
jmodules[1].color = "RGB"
|
||||
jmodules[1].gamma = 0
|
||||
jmodules[#jmodules].color = "RGB"
|
||||
jmodules[#jmodules].gamma = 0
|
||||
|
||||
local fp = io.open(output, "w")
|
||||
if not fp then
|
||||
error("IO Error: " .. output)
|
||||
end
|
||||
fp:write(cjson.encode(jmodules))
|
||||
fp:close()
|
||||
end
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("waifu2x export model")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-i", "input.t7", 'Specify the input torch model')
|
||||
cmd:option("-o", "output.json", 'Specify the output json file')
|
||||
cmd:option("-iformat", "ascii", 'Specify the input format (ascii|binary)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if not path.isfile(opt.i) then
|
||||
cmd:help()
|
||||
os.exit(-1)
|
||||
end
|
||||
local model = torch.load(opt.i, opt.iformat)
|
||||
export(model, opt.o)
|
175
train.lua
|
@ -1,21 +1,25 @@
|
|||
require './lib/portable'
|
||||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||
require 'optim'
|
||||
require 'xlua'
|
||||
require 'pl'
|
||||
|
||||
local settings = require './lib/settings'
|
||||
local minibatch_adam = require './lib/minibatch_adam'
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local pairwise_transform = require './lib/pairwise_transform'
|
||||
local image_loader = require './lib/image_loader'
|
||||
require 'w2nn'
|
||||
local settings = require 'settings'
|
||||
local srcnn = require 'srcnn'
|
||||
local minibatch_adam = require 'minibatch_adam'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local compression = require 'compression'
|
||||
local pairwise_transform = require 'pairwise_transform'
|
||||
local image_loader = require 'image_loader'
|
||||
|
||||
local function save_test_scale(model, rgb, file)
|
||||
local up = reconstruct.scale(model, settings.scale, rgb, settings.block_offset)
|
||||
local up = reconstruct.scale(model, settings.scale, rgb)
|
||||
image.save(file, up)
|
||||
end
|
||||
local function save_test_jpeg(model, rgb, file)
|
||||
local im, count = reconstruct.image(model, rgb, settings.block_offset)
|
||||
local im, count = reconstruct.image(model, rgb)
|
||||
image.save(file, im)
|
||||
end
|
||||
local function split_data(x, test_size)
|
||||
|
@ -31,14 +35,19 @@ local function split_data(x, test_size)
|
|||
end
|
||||
return train_x, valid_x
|
||||
end
|
||||
local function make_validation_set(x, transformer, n)
|
||||
local function make_validation_set(x, transformer, n, batch_size)
|
||||
n = n or 4
|
||||
local data = {}
|
||||
for i = 1, #x do
|
||||
for k = 1, n do
|
||||
local x, y = transformer(x[i], true)
|
||||
table.insert(data, {x = x:reshape(1, x:size(1), x:size(2), x:size(3)),
|
||||
y = y:reshape(1, y:size(1), y:size(2), y:size(3))})
|
||||
for k = 1, math.max(n / batch_size, 1) do
|
||||
local xy = transformer(x[i], true, batch_size)
|
||||
local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
|
||||
local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
|
||||
for j = 1, #xy do
|
||||
tx[j]:copy(xy[j][1])
|
||||
ty[j]:copy(xy[j][2])
|
||||
end
|
||||
table.insert(data, {x = tx, y = ty})
|
||||
end
|
||||
xlua.progress(i, #x)
|
||||
collectgarbage()
|
||||
|
@ -50,24 +59,92 @@ local function validate(model, criterion, data)
|
|||
for i = 1, #data do
|
||||
local z = model:forward(data[i].x:cuda())
|
||||
loss = loss + criterion:forward(z, data[i].y:cuda())
|
||||
xlua.progress(i, #data)
|
||||
if i % 10 == 0 then
|
||||
if i % 100 == 0 then
|
||||
xlua.progress(i, #data)
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
xlua.progress(#data, #data)
|
||||
return loss / #data
|
||||
end
|
||||
|
||||
local function create_criterion(model)
|
||||
if reconstruct.is_rgb(model) then
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(3, output_w * output_w)
|
||||
weight[1]:fill(0.29891 * 3) -- R
|
||||
weight[2]:fill(0.58661 * 3) -- G
|
||||
weight[3]:fill(0.11448 * 3) -- B
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
else
|
||||
return nn.MSECriterion():cuda()
|
||||
end
|
||||
end
|
||||
local function transformer(x, is_validation, n, offset)
|
||||
x = compression.decompress(x)
|
||||
n = n or settings.batch_size;
|
||||
if is_validation == nil then is_validation = false end
|
||||
local random_color_noise_rate = nil
|
||||
local random_overlay_rate = nil
|
||||
local active_cropping_rate = nil
|
||||
local active_cropping_tries = nil
|
||||
if is_validation then
|
||||
active_cropping_rate = 0
|
||||
active_cropping_tries = 0
|
||||
random_color_noise_rate = 0.0
|
||||
random_overlay_rate = 0.0
|
||||
else
|
||||
active_cropping_rate = settings.active_cropping_rate
|
||||
active_cropping_tries = settings.active_cropping_tries
|
||||
random_color_noise_rate = settings.random_color_noise_rate
|
||||
random_overlay_rate = settings.random_overlay_rate
|
||||
end
|
||||
|
||||
if settings.method == "scale" then
|
||||
return pairwise_transform.scale(x,
|
||||
settings.scale,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
elseif settings.method == "noise" then
|
||||
return pairwise_transform.jpeg(x,
|
||||
settings.style,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
max_size = settings.max_size,
|
||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
nr_rate = settings.nr_rate,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
end
|
||||
end
|
||||
|
||||
local function train()
|
||||
local model, offset = settings.create_model(settings.color)
|
||||
assert(offset == settings.block_offset)
|
||||
local criterion = nn.MSECriterion():cuda()
|
||||
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local pairwise_func = function(x, is_validation, n)
|
||||
return transformer(x, is_validation, n, offset)
|
||||
end
|
||||
local criterion = create_criterion(model)
|
||||
local x = torch.load(settings.images)
|
||||
local lrd_count = 0
|
||||
local train_x, valid_x = split_data(x,
|
||||
math.floor(settings.validation_ratio * #x),
|
||||
settings.validation_crops)
|
||||
local test = image_loader.load_float(settings.test)
|
||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
||||
local adam_config = {
|
||||
learningRate = settings.learning_rate,
|
||||
xBatchSize = settings.batch_size,
|
||||
|
@ -78,38 +155,11 @@ local function train()
|
|||
elseif settings.color == "rgb" then
|
||||
ch = 3
|
||||
end
|
||||
local transformer = function(x, is_validation)
|
||||
if is_validation == nil then is_validation = false end
|
||||
if settings.method == "scale" then
|
||||
return pairwise_transform.scale(x,
|
||||
settings.scale,
|
||||
settings.crop_size, offset,
|
||||
{ color_augment = not is_validation,
|
||||
random_half = settings.random_half,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
elseif settings.method == "noise" then
|
||||
return pairwise_transform.jpeg(x,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
{ color_augment = not is_validation,
|
||||
random_half = settings.random_half,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
elseif settings.method == "noise_scale" then
|
||||
return pairwise_transform.jpeg_scale(x,
|
||||
settings.scale,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
{ color_augment = not is_validation,
|
||||
random_half = settings.random_half,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
end
|
||||
end
|
||||
local best_score = 100000.0
|
||||
print("# make validation-set")
|
||||
local valid_xy = make_validation_set(valid_x, transformer, 20)
|
||||
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
||||
settings.validation_crops,
|
||||
settings.batch_size)
|
||||
valid_x = nil
|
||||
|
||||
collectgarbage()
|
||||
|
@ -119,7 +169,7 @@ local function train()
|
|||
model:training()
|
||||
print("# " .. epoch)
|
||||
print(minibatch_adam(model, criterion, train_x, adam_config,
|
||||
transformer,
|
||||
pairwise_func,
|
||||
{ch, settings.crop_size, settings.crop_size},
|
||||
{ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
||||
))
|
||||
|
@ -127,6 +177,7 @@ local function train()
|
|||
print("# validation")
|
||||
local score = validate(model, criterion, valid_xy)
|
||||
if score < best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
lrd_count = 0
|
||||
best_score = score
|
||||
print("* update best model")
|
||||
|
@ -134,22 +185,17 @@ local function train()
|
|||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.png"):format(settings.noise_level))
|
||||
save_test_jpeg(model, test, log)
|
||||
save_test_jpeg(model, test_image, log)
|
||||
elseif settings.method == "scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("scale%.1f_best.png"):format(settings.scale))
|
||||
save_test_scale(model, test, log)
|
||||
elseif settings.method == "noise_scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_scale%.1f_best.png"):format(settings.noise_level,
|
||||
settings.scale))
|
||||
save_test_scale(model, test, log)
|
||||
save_test_scale(model, test_image, log)
|
||||
end
|
||||
else
|
||||
lrd_count = lrd_count + 1
|
||||
if lrd_count > 5 then
|
||||
lrd_count = 0
|
||||
adam_config.learningRate = adam_config.learningRate * 0.8
|
||||
adam_config.learningRate = adam_config.learningRate * 0.9
|
||||
print("* learning rate decay: " .. adam_config.learningRate)
|
||||
end
|
||||
end
|
||||
|
@ -157,6 +203,9 @@ local function train()
|
|||
collectgarbage()
|
||||
end
|
||||
end
|
||||
if settings.gpu > 0 then
|
||||
cutorch.setDevice(settings.gpu)
|
||||
end
|
||||
torch.manualSeed(settings.seed)
|
||||
cutorch.manualSeed(settings.seed)
|
||||
print(settings)
|
||||
|
|
14
train.sh
|
@ -1,10 +1,12 @@
|
|||
#!/bin/sh
|
||||
|
||||
th train.lua -color rgb -method noise -noise_level 1 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
||||
th convert_data.lua
|
||||
|
||||
th train.lua -color rgb -method noise -noise_level 2 -model_dir models/anime_style_art_rgb -test images/miku_noisy.png
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
||||
th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -color rgb -method scale -scale 2 -model_dir models/anime_style_art_rgb -test images/miku_small.png
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
||||
th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
||||
|
|
9
train_ukbench.sh
Executable file
|
@ -0,0 +1,9 @@
|
|||
#!/bin/sh
|
||||
|
||||
th convert_data.lua -data_dir ./data/ukbench
|
||||
|
||||
#th train.lua -style photo -method noise -noise_level 2 -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.png -nr_rate 0.9 -jpeg_sampling_factors 420 # -thread 4 -backend cudnn
|
||||
#th tools/cleanup_model.lua -model models/ukbench/noise2_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg # -thread 4 -backend cudnn
|
||||
th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii
|
166
waifu2x.lua
|
@ -1,12 +1,11 @@
|
|||
require './lib/portable'
|
||||
require 'sys'
|
||||
require 'pl'
|
||||
require './lib/LeakyReLU'
|
||||
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
local BLOCK_OFFSET = 7
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||
require 'sys'
|
||||
require 'w2nn'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
||||
|
@ -14,43 +13,109 @@ local function convert_image(opt)
|
|||
local x, alpha = image_loader.load_float(opt.i)
|
||||
local new_x = nil
|
||||
local t = sys.clock()
|
||||
local scale_f, image_f
|
||||
if opt.tta == 1 then
|
||||
scale_f = reconstruct.scale_tta
|
||||
image_f = reconstruct.image_tta
|
||||
else
|
||||
scale_f = reconstruct.scale
|
||||
image_f = reconstruct.image
|
||||
end
|
||||
if opt.o == "(auto)" then
|
||||
local name = path.basename(opt.i)
|
||||
local e = path.extension(name)
|
||||
local base = name:sub(0, name:len() - e:len())
|
||||
opt.o = path.join(path.dirname(opt.i), string.format("%s(%s).png", base, opt.m))
|
||||
opt.o = path.join(path.dirname(opt.i), string.format("%s_%s.png", base, opt.m))
|
||||
end
|
||||
if opt.m == "noise" then
|
||||
local model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
|
||||
model:evaluate()
|
||||
new_x = reconstruct.image(model, x, BLOCK_OFFSET, opt.crop_size)
|
||||
local model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
|
||||
local model = torch.load(model_path, "ascii")
|
||||
if not model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
new_x = image_f(model, x, opt.crop_size)
|
||||
elseif opt.m == "scale" then
|
||||
local model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
|
||||
model:evaluate()
|
||||
new_x = reconstruct.scale(model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
local model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
local model = torch.load(model_path, "ascii")
|
||||
if not model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
new_x = scale_f(model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" then
|
||||
local noise_model = torch.load(path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level)), "ascii")
|
||||
local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
|
||||
noise_model:evaluate()
|
||||
scale_model:evaluate()
|
||||
x = reconstruct.image(noise_model, x, BLOCK_OFFSET)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
local noise_model_path = path.join(opt.model_dir, ("noise%d_model.t7"):format(opt.noise_level))
|
||||
local noise_model = torch.load(noise_model_path, "ascii")
|
||||
local scale_model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
local scale_model = torch.load(scale_model_path, "ascii")
|
||||
|
||||
if not noise_model then
|
||||
error("Load Error: " .. noise_model_path)
|
||||
end
|
||||
if not scale_model then
|
||||
error("Load Error: " .. scale_model_path)
|
||||
end
|
||||
x = image_f(noise_model, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
image_loader.save_png(opt.o, new_x, alpha)
|
||||
if opt.white_noise == 1 then
|
||||
new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
|
||||
end
|
||||
image_loader.save_png(opt.o, new_x, alpha, opt.depth)
|
||||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||
end
|
||||
local function convert_frames(opt)
|
||||
local noise1_model = torch.load(path.join(opt.model_dir, "noise1_model.t7"), "ascii")
|
||||
local noise2_model = torch.load(path.join(opt.model_dir, "noise2_model.t7"), "ascii")
|
||||
local scale_model = torch.load(path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale)), "ascii")
|
||||
|
||||
noise1_model:evaluate()
|
||||
noise2_model:evaluate()
|
||||
scale_model:evaluate()
|
||||
|
||||
local model_path, noise1_model, noise2_model, scale_model
|
||||
local scale_f, image_f
|
||||
if opt.tta == 1 then
|
||||
scale_f = reconstruct.scale_tta
|
||||
image_f = reconstruct.image_tta
|
||||
else
|
||||
scale_f = reconstruct.scale
|
||||
image_f = reconstruct.image
|
||||
end
|
||||
if opt.m == "scale" then
|
||||
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
scale_model = torch.load(model_path, "ascii")
|
||||
if not scale_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise" and opt.noise_level == 1 then
|
||||
model_path = path.join(opt.model_dir, "noise1_model.t7")
|
||||
noise1_model = torch.load(model_path, "ascii")
|
||||
if not noise1_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||
model_path = path.join(opt.model_dir, "noise2_model.t7")
|
||||
noise2_model = torch.load(model_path, "ascii")
|
||||
if not noise2_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise_scale" then
|
||||
model_path = path.join(opt.model_dir, ("scale%.1fx_model.t7"):format(opt.scale))
|
||||
scale_model = torch.load(model_path, "ascii")
|
||||
if not scale_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
if opt.noise_level == 1 then
|
||||
model_path = path.join(opt.model_dir, "noise1_model.t7")
|
||||
noise1_model = torch.load(model_path, "ascii")
|
||||
if not noise1_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.noise_level == 2 then
|
||||
model_path = path.join(opt.model_dir, "noise2_model.t7")
|
||||
noise2_model = torch.load(model_path, "ascii")
|
||||
if not noise2_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
end
|
||||
end
|
||||
local fp = io.open(opt.l)
|
||||
if not fp then
|
||||
error("Open Error: " .. opt.l)
|
||||
end
|
||||
local count = 0
|
||||
local lines = {}
|
||||
for line in fp:lines() do
|
||||
|
@ -62,20 +127,24 @@ local function convert_frames(opt)
|
|||
local x, alpha = image_loader.load_float(lines[i])
|
||||
local new_x = nil
|
||||
if opt.m == "noise" and opt.noise_level == 1 then
|
||||
new_x = reconstruct.image(noise1_model, x, BLOCK_OFFSET, opt.crop_size)
|
||||
new_x = image_f(noise1_model, x, opt.crop_size)
|
||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||
new_x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
||||
new_x = image_func(noise2_model, x, opt.crop_size)
|
||||
elseif opt.m == "scale" then
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
|
||||
x = reconstruct.image(noise1_model, x, BLOCK_OFFSET)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
x = image_f(noise1_model, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
|
||||
x = reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
||||
new_x = reconstruct.scale(scale_model, opt.scale, x, BLOCK_OFFSET, opt.crop_size)
|
||||
x = image_f(noise2_model, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
else
|
||||
error("undefined method:" .. opt.method)
|
||||
end
|
||||
if opt.white_noise == 1 then
|
||||
new_x = iproc.white_noise(new_x, opt.white_noise_std, {1.0, 0.8, 1.0})
|
||||
end
|
||||
|
||||
local output = nil
|
||||
if opt.o == "(auto)" then
|
||||
local name = path.basename(lines[i])
|
||||
|
@ -85,7 +154,7 @@ local function convert_frames(opt)
|
|||
else
|
||||
output = string.format(opt.o, i)
|
||||
end
|
||||
image_loader.save_png(output, new_x, alpha)
|
||||
image_loader.save_png(output, new_x, alpha, opt.depth)
|
||||
xlua.progress(i, #lines)
|
||||
if i % 10 == 0 then
|
||||
collectgarbage()
|
||||
|
@ -101,17 +170,30 @@ local function waifu2x()
|
|||
cmd:text()
|
||||
cmd:text("waifu2x")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-i", "images/miku_small.png", 'path of the input image')
|
||||
cmd:option("-l", "", 'path of the image-list')
|
||||
cmd:option("-i", "images/miku_small.png", 'path to input image')
|
||||
cmd:option("-l", "", 'path to image-list.txt')
|
||||
cmd:option("-scale", 2, 'scale factor')
|
||||
cmd:option("-o", "(auto)", 'path of the output file')
|
||||
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
|
||||
cmd:option("-o", "(auto)", 'path to output file')
|
||||
cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
|
||||
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
|
||||
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
|
||||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-crop_size", 128, 'patch size per process')
|
||||
cmd:option("-resume", 0, "skip existing files (0|1)")
|
||||
cmd:option("-thread", -1, "number of CPU threads")
|
||||
cmd:option("-tta", 0, '8x slower and slightly high quality (0|1)')
|
||||
cmd:option("-white_noise", 0, 'adding white noise to output image (0|1)')
|
||||
cmd:option("-white_noise_std", 0.0055, 'standard division of white noise')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
if opt.thread > 0 then
|
||||
torch.setnumthreads(opt.thread)
|
||||
end
|
||||
if cudnn then
|
||||
cudnn.fastest = true
|
||||
cudnn.benchmark = false
|
||||
end
|
||||
|
||||
if string.len(opt.l) == 0 then
|
||||
convert_image(opt)
|
||||
else
|
||||
|
|
209
web.lua
|
@ -1,11 +1,21 @@
|
|||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
local ROOT = path.dirname(__FILE__)
|
||||
package.path = path.join(ROOT, "lib", "?.lua;") .. package.path
|
||||
_G.TURBO_SSL = true
|
||||
local turbo = require 'turbo'
|
||||
|
||||
require 'w2nn'
|
||||
local uuid = require 'uuid'
|
||||
local ffi = require 'ffi'
|
||||
local md5 = require 'md5'
|
||||
require 'pl'
|
||||
require './lib/portable'
|
||||
require './lib/LeakyReLU'
|
||||
local iproc = require 'iproc'
|
||||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
|
||||
-- Notes: turbo and xlua has different implementation of string:split().
|
||||
-- Therefore, string:split() has conflict issue.
|
||||
-- In this script, use turbo's string:split().
|
||||
local turbo = require 'turbo'
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
|
@ -13,24 +23,27 @@ cmd:text("waifu2x-api")
|
|||
cmd:text("Options:")
|
||||
cmd:option("-port", 8812, 'listen port')
|
||||
cmd:option("-gpu", 1, 'Device ID')
|
||||
cmd:option("-core", 2, 'number of CPU cores')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
local opt = cmd:parse(arg)
|
||||
cutorch.setDevice(opt.gpu)
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
torch.setnumthreads(opt.core)
|
||||
|
||||
local iproc = require './lib/iproc'
|
||||
local reconstruct = require './lib/reconstruct'
|
||||
local image_loader = require './lib/image_loader'
|
||||
|
||||
local MODEL_DIR = "./models/anime_style_art_rgb"
|
||||
|
||||
local noise1_model = torch.load(path.join(MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
local noise2_model = torch.load(path.join(MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local scale20_model = torch.load(path.join(MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
|
||||
local USE_CACHE = true
|
||||
local CACHE_DIR = "./cache"
|
||||
if opt.thread > 0 then
|
||||
torch.setnumthreads(opt.thread)
|
||||
end
|
||||
if cudnn then
|
||||
cudnn.fastest = true
|
||||
cudnn.benchmark = false
|
||||
end
|
||||
local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
|
||||
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
|
||||
local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
--local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
--local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
--local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could use this flag.
|
||||
local CACHE_DIR = path.join(ROOT, "cache")
|
||||
local MAX_NOISE_IMAGE = 2560 * 2560
|
||||
local MAX_SCALE_IMAGE = 1280 * 1280
|
||||
local CURL_OPTIONS = {
|
||||
|
@ -40,7 +53,6 @@ local CURL_OPTIONS = {
|
|||
max_redirects = 2
|
||||
}
|
||||
local CURL_MAX_SIZE = 2 * 1024 * 1024
|
||||
local BLOCK_OFFSET = 7 -- see srcnn.lua
|
||||
|
||||
local function valid_size(x, scale)
|
||||
if scale == 0 then
|
||||
|
@ -50,20 +62,16 @@ local function valid_size(x, scale)
|
|||
end
|
||||
end
|
||||
|
||||
local function get_image(req)
|
||||
local file = req:get_argument("file", "")
|
||||
local url = req:get_argument("url", "")
|
||||
local blob = nil
|
||||
local img = nil
|
||||
local alpha = nil
|
||||
if file and file:len() > 0 then
|
||||
blob = file
|
||||
img, alpha = image_loader.decode_float(blob)
|
||||
elseif url and url:len() > 0 then
|
||||
local function cache_url(url)
|
||||
local hash = md5.sumhexa(url)
|
||||
local cache_file = path.join(CACHE_DIR, "url_" .. hash)
|
||||
if path.exists(cache_file) then
|
||||
return image_loader.load_float(cache_file)
|
||||
else
|
||||
local res = coroutine.yield(
|
||||
turbo.async.HTTPClient({verify_ca=false},
|
||||
nil,
|
||||
CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
|
||||
nil,
|
||||
CURL_MAX_SIZE):fetch(url, CURL_OPTIONS)
|
||||
)
|
||||
if res.code == 200 then
|
||||
local content_type = res.headers:get("Content-Type", true)
|
||||
|
@ -71,33 +79,64 @@ local function get_image(req)
|
|||
content_type = content_type[1]
|
||||
end
|
||||
if content_type and content_type:find("image") then
|
||||
blob = res.body
|
||||
img, alpha = image_loader.decode_float(blob)
|
||||
local fp = io.open(cache_file, "wb")
|
||||
local blob = res.body
|
||||
fp:write(blob)
|
||||
fp:close()
|
||||
return image_loader.decode_float(blob)
|
||||
end
|
||||
end
|
||||
end
|
||||
return img, blob, alpha
|
||||
return nil, nil, nil
|
||||
end
|
||||
|
||||
local function apply_denoise1(x)
|
||||
return reconstruct.image(noise1_model, x, BLOCK_OFFSET)
|
||||
local function get_image(req)
|
||||
local file = req:get_argument("file", "")
|
||||
local url = req:get_argument("url", "")
|
||||
if file and file:len() > 0 then
|
||||
return image_loader.decode_float(file)
|
||||
elseif url and url:len() > 0 then
|
||||
return cache_url(url)
|
||||
end
|
||||
return nil, nil, nil
|
||||
end
|
||||
local function apply_denoise2(x)
|
||||
return reconstruct.image(noise2_model, x, BLOCK_OFFSET)
|
||||
local function cleanup_model(model)
|
||||
if CLEANUP_MODEL then
|
||||
w2nn.cleanup_model(model) -- release GPU memory
|
||||
end
|
||||
end
|
||||
local function apply_scale2x(x)
|
||||
return reconstruct.scale(scale20_model, 2.0, x, BLOCK_OFFSET)
|
||||
end
|
||||
local function cache_do(cache, x, func)
|
||||
if path.exists(cache) then
|
||||
return image.load(cache)
|
||||
local function convert(x, options)
|
||||
local cache_file = path.join(CACHE_DIR, options.prefix .. ".png")
|
||||
if path.exists(cache_file) then
|
||||
return image.load(cache_file)
|
||||
else
|
||||
x = func(x)
|
||||
image.save(cache, x)
|
||||
if options.style == "art" then
|
||||
if options.method == "scale" then
|
||||
x = reconstruct.scale(art_scale2_model, 2.0, x)
|
||||
cleanup_model(art_scale2_model)
|
||||
elseif options.method == "noise1" then
|
||||
x = reconstruct.image(art_noise1_model, x)
|
||||
cleanup_model(art_noise1_model)
|
||||
else -- options.method == "noise2"
|
||||
x = reconstruct.image(art_noise2_model, x)
|
||||
cleanup_model(art_noise2_model)
|
||||
end
|
||||
else --[[photo
|
||||
if options.method == "scale" then
|
||||
x = reconstruct.scale(photo_scale2_model, 2.0, x)
|
||||
cleanup_model(photo_scale2_model)
|
||||
elseif options.method == "noise1" then
|
||||
x = reconstruct.image(photo_noise1_model, x)
|
||||
cleanup_model(photo_noise1_model)
|
||||
elseif options.method == "noise2" then
|
||||
x = reconstruct.image(photo_noise2_model, x)
|
||||
cleanup_model(photo_noise2_model)
|
||||
end
|
||||
--]]
|
||||
end
|
||||
image.save(cache_file, x)
|
||||
return x
|
||||
end
|
||||
end
|
||||
|
||||
local function client_disconnected(handler)
|
||||
return not(handler.request and
|
||||
handler.request.connection and
|
||||
|
@ -112,63 +151,51 @@ function APIHandler:post()
|
|||
self:write("client disconnected")
|
||||
return
|
||||
end
|
||||
local x, src, alpha = get_image(self)
|
||||
local x, alpha, blob = get_image(self)
|
||||
local scale = tonumber(self:get_argument("scale", "0"))
|
||||
local noise = tonumber(self:get_argument("noise", "0"))
|
||||
local white_noise = tonumber(self:get_argument("white_noise", "0"))
|
||||
local style = self:get_argument("style", "art")
|
||||
if style ~= "art" then
|
||||
style = "photo" -- style must be art or photo
|
||||
end
|
||||
if x and valid_size(x, scale) then
|
||||
if USE_CACHE and (noise ~= 0 or scale ~= 0) then
|
||||
local hash = md5.sumhexa(src)
|
||||
local cache_noise1 = path.join(CACHE_DIR, hash .. "_noise1.png")
|
||||
local cache_noise2 = path.join(CACHE_DIR, hash .. "_noise2.png")
|
||||
local cache_scale = path.join(CACHE_DIR, hash .. "_scale.png")
|
||||
local cache_noise1_scale = path.join(CACHE_DIR, hash .. "_noise1_scale.png")
|
||||
local cache_noise2_scale = path.join(CACHE_DIR, hash .. "_noise2_scale.png")
|
||||
|
||||
if (noise ~= 0 or scale ~= 0) then
|
||||
local hash = md5.sumhexa(blob)
|
||||
if noise == 1 then
|
||||
x = cache_do(cache_noise1, x, apply_denoise1)
|
||||
x = convert(x, {method = "noise1", style = style, prefix = style .. "_noise1_" .. hash})
|
||||
elseif noise == 2 then
|
||||
x = cache_do(cache_noise2, x, apply_denoise2)
|
||||
x = convert(x, {method = "noise2", style = style, prefix = style .. "_noise2_" .. hash})
|
||||
end
|
||||
if scale == 1 or scale == 2 then
|
||||
if noise == 1 then
|
||||
x = cache_do(cache_noise1_scale, x, apply_scale2x)
|
||||
x = convert(x, {method = "scale", style = style, prefix = style .. "_noise1_scale_" .. hash})
|
||||
elseif noise == 2 then
|
||||
x = cache_do(cache_noise2_scale, x, apply_scale2x)
|
||||
x = convert(x, {method = "scale", style = style, prefix = style .. "_noise2_scale_" .. hash})
|
||||
else
|
||||
x = cache_do(cache_scale, x, apply_scale2x)
|
||||
x = convert(x, {method = "scale", style = style, prefix = style .. "_scale_" .. hash})
|
||||
end
|
||||
if scale == 1 then
|
||||
x = iproc.scale(x,
|
||||
math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
|
||||
math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
|
||||
"Jinc")
|
||||
x = iproc.scale_with_gamma22(x,
|
||||
math.floor(x:size(3) * (1.6 / 2.0) + 0.5),
|
||||
math.floor(x:size(2) * (1.6 / 2.0) + 0.5),
|
||||
"Jinc")
|
||||
end
|
||||
end
|
||||
elseif noise ~= 0 or scale ~= 0 then
|
||||
if noise == 1 then
|
||||
x = apply_denoise1(x)
|
||||
elseif noise == 2 then
|
||||
x = apply_denoise2(x)
|
||||
end
|
||||
if scale == 1 then
|
||||
local x16 = {math.floor(x:size(3) * 1.6 + 0.5), math.floor(x:size(2) * 1.6 + 0.5)}
|
||||
x = apply_scale2x(x)
|
||||
x = iproc.scale(x, x16[1], x16[2], "Jinc")
|
||||
elseif scale == 2 then
|
||||
x = apply_scale2x(x)
|
||||
if white_noise == 1 then
|
||||
x = iproc.white_noise(x, 0.005, {1.0, 0.8, 1.0})
|
||||
end
|
||||
end
|
||||
local name = uuid() .. ".png"
|
||||
local blob, len = image_loader.encode_png(x, alpha)
|
||||
|
||||
local blob = image_loader.encode_png(x, alpha)
|
||||
self:set_header("Content-Disposition", string.format('filename="%s"', name))
|
||||
self:set_header("Content-Type", "image/png")
|
||||
self:set_header("Content-Length", string.format("%d", len))
|
||||
self:write(ffi.string(blob, len))
|
||||
self:set_header("Content-Length", string.format("%d", #blob))
|
||||
self:write(blob)
|
||||
else
|
||||
if not x then
|
||||
self:set_status(400)
|
||||
self:write("ERROR: unsupported image format.")
|
||||
self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)")
|
||||
else
|
||||
self:set_status(400)
|
||||
self:write("ERROR: image size exceeds maximum allowable size.")
|
||||
|
@ -177,9 +204,9 @@ function APIHandler:post()
|
|||
collectgarbage()
|
||||
end
|
||||
local FormHandler = class("FormHandler", turbo.web.RequestHandler)
|
||||
local index_ja = file.read("./assets/index.ja.html")
|
||||
local index_ru = file.read("./assets/index.ru.html")
|
||||
local index_en = file.read("./assets/index.html")
|
||||
local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html"))
|
||||
local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html"))
|
||||
local index_en = file.read(path.join(ROOT, "assets", "index.html"))
|
||||
function FormHandler:get()
|
||||
local lang = self.request.headers:get("Accept-Language")
|
||||
if lang then
|
||||
|
@ -209,9 +236,11 @@ turbo.log.categories = {
|
|||
local app = turbo.web.Application:new(
|
||||
{
|
||||
{"^/$", FormHandler},
|
||||
{"^/index.html", turbo.web.StaticFileHandler, path.join("./assets", "index.html")},
|
||||
{"^/index.ja.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ja.html")},
|
||||
{"^/index.ru.html", turbo.web.StaticFileHandler, path.join("./assets", "index.ru.html")},
|
||||
{"^/style.css", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "style.css")},
|
||||
{"^/ui.js", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "ui.js")},
|
||||
{"^/index.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.html")},
|
||||
{"^/index.ja.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ja.html")},
|
||||
{"^/index.ru.html", turbo.web.StaticFileHandler, path.join(ROOT, "assets", "index.ru.html")},
|
||||
{"^/api$", APIHandler},
|
||||
}
|
||||
)
|
||||
|
|