Merge pull request #71 from nagadomi/photo
Add support for photo scaling/jpeg denoising.
This commit is contained in:
commit
e4b239ec14
12
README.md
12
README.md
|
@ -1,6 +1,7 @@
|
|||
# waifu2x
|
||||
|
||||
Image Super-Resolution for anime-style-art using Deep Convolutional Neural Networks.
|
||||
Image Super-Resolution for Anime-style art using Deep Convolutional Neural Networks.
|
||||
It also supports photos.
|
||||
|
||||
Demo-Application can be found at http://waifu2x.udp.jp/ .
|
||||
|
||||
|
@ -123,6 +124,15 @@ th waifu2x.lua -m noise_scale -noise_level 2 -i input_image.png -o output_image.
|
|||
|
||||
See also `th waifu2x.lua -h`.
|
||||
|
||||
### Using photo model
|
||||
|
||||
To use the photo model, add `-model_dir models/photo` to the command-line options.
|
||||
For example,
|
||||
|
||||
```
|
||||
th waifu2x.lua -model_dir models/photo -m scale -i input_image.png -o output_image.png
|
||||
```
|
||||
|
||||
### Video Encoding
|
||||
|
||||
\* `avconv` is alias of `ffmpeg` on Ubuntu 14.04.
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
<meta charset="UTF-8">
|
||||
<title>waifu2x</title>
|
||||
<link href="style.css" rel="stylesheet" type="text/css">
|
||||
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
|
||||
<script type="text/javascript" src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
|
||||
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
|
||||
<script type="text/javascript" src="ui.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
|
@ -17,7 +18,7 @@
|
|||
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
|
||||
</div>
|
||||
<div class="about">
|
||||
<div>Single-Image Super-Resolution for anime/fan-art using Deep Convolutional Neural Networks. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
|
||||
<div>Single-Image Super-Resolution for Anime-Style Art using Deep Convolutional Neural Networks. And it supports photo. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
|
||||
</div>
|
||||
<form action="/api" method="POST" enctype="multipart/form-data" target="_blank">
|
||||
<fieldset>
|
||||
|
@ -32,6 +33,11 @@
|
|||
Limits: Size: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<legend>Style</legend>
|
||||
<label><input type="radio" name="style" value="art" checked>Art</label>
|
||||
<label><input type="radio" name="style" value="photo">Photo</label>
|
||||
</fieldset>
|
||||
<fieldset class="noise-field">
|
||||
<legend>Noise Reduction (expect JPEG Artifact)</legend>
|
||||
<label><input type="radio" name="noise" value="0"> None</label>
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
<link href="style.css" rel="stylesheet" type="text/css">
|
||||
<title>waifu2x</title>
|
||||
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
|
||||
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
|
||||
<script type="text/javascript" src="ui.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
|
@ -17,7 +18,7 @@
|
|||
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
|
||||
</div>
|
||||
<div class="about">
|
||||
<div>深層畳み込みニューラルネットワークによる二次元画像のための超解像システム. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
|
||||
<div>深層畳み込みニューラルネットワークによる二次元画像のための超解像システム. 写真にも対応. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
|
||||
</div>
|
||||
<form action="/api" method="POST" enctype="multipart/form-data" target="_blank">
|
||||
<fieldset>
|
||||
|
@ -32,6 +33,11 @@
|
|||
制限: サイズ: 2MB, ノイズ除去: 2560x2560px, 拡大: 1280x1280px
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<legend>スタイル</legend>
|
||||
<label><input type="radio" name="style" value="art" checked>イラスト</label>
|
||||
<label><input type="radio" name="style" value="photo">写真</label>
|
||||
</fieldset>
|
||||
<fieldset class="noise-field">
|
||||
<legend>ノイズ除去 (JPEGノイズを想定)</legend>
|
||||
<label><input type="radio" name="noise" value="0"> なし</label>
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
<title>waifu2x</title>
|
||||
<link href="style.css" rel="stylesheet" type="text/css">
|
||||
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
|
||||
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
|
||||
<script type="text/javascript" src="ui.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
|
@ -33,6 +34,11 @@
|
|||
Макс. размер файла — 2MB, устранение шума — макс. 2560x2560px, апскейл — 1280x1280px
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<legend>Стиль</legend>
|
||||
<label><input type="radio" name="style" value="art" checked>Произведение искусства</label>
|
||||
<label><input type="radio" name="style" value="photo">фото</label>
|
||||
</fieldset>
|
||||
<fieldset class="noise-field">
|
||||
<legend>Устранение шума (артефактов JPEG)</legend>
|
||||
<label><input type="radio" name="noise" value="0"> Нет</label>
|
||||
|
|
40
assets/ui.js
40
assets/ui.js
|
@ -1,4 +1,5 @@
|
|||
$(function (){
|
||||
var expires = 365;
|
||||
function clear_file() {
|
||||
var new_file = $("#file").clone();
|
||||
new_file.change(clear_url);
|
||||
|
@ -19,6 +20,7 @@ $(function (){
|
|||
} else {
|
||||
$("h1").html("w<s>/a/</s>ifu2x");
|
||||
}
|
||||
$.cookie("style", checked.val(), {expires: expires});
|
||||
}
|
||||
function on_change_noise_level(e)
|
||||
{
|
||||
|
@ -30,6 +32,7 @@ $(function (){
|
|||
if (checked.val() != 0) {
|
||||
checked.parents("label").css("font-weight", "bold");
|
||||
}
|
||||
$.cookie("noise", checked.val(), {expires: expires});
|
||||
}
|
||||
function on_change_scale_factor(e)
|
||||
{
|
||||
|
@ -41,40 +44,29 @@ $(function (){
|
|||
if (checked.val() != 0) {
|
||||
checked.parents("label").css("font-weight", "bold");
|
||||
}
|
||||
$.cookie("scale", checked.val(), {expires: expires});
|
||||
}
|
||||
function on_change_white_noise(e)
|
||||
function restore_from_cookie()
|
||||
{
|
||||
$("input[name=white_noise]").parents("label").each(
|
||||
function (i, elm) {
|
||||
$(elm).css("font-weight", "normal");
|
||||
});
|
||||
var checked = $("input[name=white_noise]:checked");
|
||||
if (checked.val() != 0) {
|
||||
checked.parents("label").css("font-weight", "bold");
|
||||
if ($.cookie("style")) {
|
||||
$("input[name=style]").filter("[value=" + $.cookie("style") + "]").prop("checked", true)
|
||||
}
|
||||
}
|
||||
function on_click_experimental_button(e)
|
||||
{
|
||||
if ($(this).hasClass("close")) {
|
||||
$(".experimental .container").show();
|
||||
$(this).removeClass("close");
|
||||
} else {
|
||||
$(".experimental .container").hide();
|
||||
$(this).addClass("close");
|
||||
if ($.cookie("noise")) {
|
||||
$("input[name=noise]").filter("[value=" + $.cookie("noise") + "]").prop("checked", true)
|
||||
}
|
||||
if ($.cookie("scale")) {
|
||||
$("input[name=scale]").filter("[value=" + $.cookie("scale") + "]").prop("checked", true)
|
||||
}
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}
|
||||
|
||||
$("#url").change(clear_file);
|
||||
$("#file").change(clear_url);
|
||||
//$("input[name=style]").change(on_change_style);
|
||||
$("input[name=style]").change(on_change_style);
|
||||
$("input[name=noise]").change(on_change_noise_level);
|
||||
$("input[name=scale]").change(on_change_scale_factor);
|
||||
//$("input[name=white_noise]").change(on_change_white_noise);
|
||||
//$(".experimental .button").click(on_click_experimental_button)
|
||||
|
||||
//on_change_style();
|
||||
|
||||
restore_from_cookie();
|
||||
on_change_style();
|
||||
on_change_scale_factor();
|
||||
on_change_noise_level();
|
||||
})
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
require 'image'
|
||||
local iproc = require 'iproc'
|
||||
local gm = require 'graphicsmagick'
|
||||
|
||||
local data_augmentation = {}
|
||||
|
||||
|
@ -50,6 +51,25 @@ function data_augmentation.overlay(src, p)
|
|||
return src
|
||||
end
|
||||
end
|
||||
--- Randomly sharpen an image with an unsharp mask.
-- One uniform sample is compared against `p`; when the sample is not below
-- `p`, `src` is returned untouched. Otherwise sigma, amount and threshold are
-- drawn uniformly (in that order) and passed to GraphicsMagick's unsharpMask.
-- @param src image tensor handed to gm.Image as "RGB"/"DHW"
--            (assumes a 3xHxW byte or float tensor -- TODO confirm with callers)
-- @param p   rate compared against a torch.uniform() sample
-- @return the sharpened image, converted with iproc.float2byte when `src`
--         is a torch.ByteTensor; otherwise the float result, or `src` itself
--         when the augmentation was not applied
function data_augmentation.unsharp_mask(src, p)
   if not (torch.uniform() < p) then
      -- augmentation not selected this time
      return src
   end

   local radius = 0 -- auto
   local sigma = torch.uniform(0.5, 1.5)
   local amount = torch.uniform(0.1, 0.9)
   local threshold = torch.uniform(0.0, 0.05)

   local wand = gm.Image(src, "RGB", "DHW")
   local sharpened = wand:unsharpMask(radius, sigma, amount, threshold)
   local unsharp = sharpened:toTensor("float", "RGB", "DHW")

   -- hand back the same tensor type the caller passed in
   if src:type() == "torch.ByteTensor" then
      return iproc.float2byte(unsharp)
   end
   return unsharp
end
|
||||
function data_augmentation.shift_1px(src)
|
||||
-- reducing the even/odd issue in nearest neighbor scaler.
|
||||
local direction = torch.random(1, 4)
|
||||
|
|
|
@ -3,30 +3,32 @@ require 'cutorch'
|
|||
require 'xlua'
|
||||
|
||||
local function minibatch_adam(model, criterion,
|
||||
train_x,
|
||||
config, transformer,
|
||||
input_size, target_size)
|
||||
train_x, train_y,
|
||||
config)
|
||||
local parameters, gradParameters = model:getParameters()
|
||||
config = config or {}
|
||||
local sum_loss = 0
|
||||
local count_loss = 0
|
||||
local batch_size = config.xBatchSize or 32
|
||||
local shuffle = torch.randperm(#train_x)
|
||||
local shuffle = torch.randperm(train_x:size(1))
|
||||
local c = 1
|
||||
local inputs = torch.Tensor(batch_size,
|
||||
input_size[1], input_size[2], input_size[3]):cuda()
|
||||
local targets = torch.Tensor(batch_size,
|
||||
target_size[1] * target_size[2] * target_size[3]):cuda()
|
||||
local inputs_tmp = torch.Tensor(batch_size,
|
||||
input_size[1], input_size[2], input_size[3])
|
||||
train_x:size(2), train_x:size(3), train_x:size(4)):zero()
|
||||
local targets_tmp = torch.Tensor(batch_size,
|
||||
target_size[1] * target_size[2] * target_size[3])
|
||||
for t = 1, #train_x do
|
||||
xlua.progress(t, #train_x)
|
||||
local xy = transformer(train_x[shuffle[t]], false, batch_size)
|
||||
for i = 1, #xy do
|
||||
inputs_tmp[i]:copy(xy[i][1])
|
||||
targets_tmp[i]:copy(xy[i][2])
|
||||
train_y:size(2)):zero()
|
||||
local inputs = inputs_tmp:clone():cuda()
|
||||
local targets = targets_tmp:clone():cuda()
|
||||
|
||||
print("## update")
|
||||
for t = 1, train_x:size(1), batch_size do
|
||||
if t + batch_size -1 > train_x:size(1) then
|
||||
break
|
||||
end
|
||||
xlua.progress(t, train_x:size(1))
|
||||
|
||||
for i = 1, batch_size do
|
||||
inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
|
||||
targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
|
||||
end
|
||||
inputs:copy(inputs_tmp)
|
||||
targets:copy(targets_tmp)
|
||||
|
@ -43,13 +45,12 @@ local function minibatch_adam(model, criterion,
|
|||
return f, gradParameters
|
||||
end
|
||||
optim.adam(feval, parameters, config)
|
||||
|
||||
c = c + 1
|
||||
if c % 20 == 0 then
|
||||
if c % 50 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
xlua.progress(#train_x, #train_x)
|
||||
xlua.progress(train_x:size(1), train_x:size(1))
|
||||
|
||||
return { loss = sum_loss / count_loss}
|
||||
end
|
||||
|
|
|
@ -7,7 +7,7 @@ local pairwise_transform = {}
|
|||
|
||||
local function random_half(src, p)
|
||||
if torch.uniform() < p then
|
||||
local filter = ({"Box","Box","Blackman","Sinc","Lanczos"})[torch.random(1, 5)]
|
||||
local filter = ({"Box","Box","Blackman","Sinc","Lanczos", "Catrom"})[torch.random(1, 6)]
|
||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||
else
|
||||
return src
|
||||
|
@ -38,6 +38,7 @@ local function preprocess(src, crop_size, options)
|
|||
dest = data_augmentation.flip(dest)
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
|
||||
dest = data_augmentation.shift_1px(dest)
|
||||
|
||||
return dest
|
||||
|
@ -45,6 +46,10 @@ end
|
|||
local function active_cropping(x, y, size, p, tries)
|
||||
-- assert(v, message) tests its FIRST argument. The original call passed the
-- message string first, which is always truthy, so the size check could never
-- fail. The condition now comes first and the text is the error message.
assert(x:size(2) == y:size(2) and x:size(3) == y:size(3), "x:size == y:size")
|
||||
local r = torch.uniform()
|
||||
local t = "float"
|
||||
if x:type() == "torch.ByteTensor" then
|
||||
t = "byte"
|
||||
end
|
||||
if p < r then
|
||||
local xi = torch.random(0, y:size(3) - (size + 1))
|
||||
local yi = torch.random(0, y:size(2) - (size + 1))
|
||||
|
@ -52,6 +57,10 @@ local function active_cropping(x, y, size, p, tries)
|
|||
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||
return xc, yc
|
||||
else
|
||||
local lowres = gm.Image(x, "RGB", "DHW"):
|
||||
size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
|
||||
size(x:size(3), x:size(2), "Box"):
|
||||
toTensor(t, "RGB", "DHW")
|
||||
local best_se = 0.0
|
||||
local best_xc, best_yc
|
||||
local m = torch.FloatTensor(x:size(1), size, size)
|
||||
|
@ -59,13 +68,13 @@ local function active_cropping(x, y, size, p, tries)
|
|||
local xi = torch.random(0, y:size(3) - (size + 1))
|
||||
local yi = torch.random(0, y:size(2) - (size + 1))
|
||||
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
|
||||
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||
local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
|
||||
local xcf = iproc.byte2float(xc)
|
||||
local ycf = iproc.byte2float(yc)
|
||||
local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
|
||||
local lcf = iproc.byte2float(lc)
|
||||
local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
|
||||
if se >= best_se then
|
||||
best_xc = xcf
|
||||
best_yc = ycf
|
||||
best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size))
|
||||
best_se = se
|
||||
end
|
||||
end
|
||||
|
@ -73,15 +82,23 @@ local function active_cropping(x, y, size, p, tries)
|
|||
end
|
||||
end
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
local filters = {
|
||||
"Box","Box", -- 0.012756949974688
|
||||
"Blackman", -- 0.013191924552285
|
||||
--"Cartom", -- 0.013753536746706
|
||||
--"Hanning", -- 0.013761314529647
|
||||
--"Hermite", -- 0.013850225205266
|
||||
"Sinc", -- 0.014095824314306
|
||||
"Lanczos", -- 0.014244299255442
|
||||
}
|
||||
local filters;
|
||||
|
||||
if options.style == "photo" then
|
||||
filters = {
|
||||
"Box", "lanczos", "Catrom"
|
||||
}
|
||||
else
|
||||
filters = {
|
||||
"Box","Box", -- 0.012756949974688
|
||||
"Blackman", -- 0.013191924552285
|
||||
--"Catrom", -- 0.013753536746706
|
||||
--"Hanning", -- 0.013761314529647
|
||||
--"Hermite", -- 0.013850225205266
|
||||
"Sinc", -- 0.014095824314306
|
||||
"Lanczos", -- 0.014244299255442
|
||||
}
|
||||
end
|
||||
local unstable_region_offset = 8
|
||||
local downscale_filter = filters[torch.random(1, #filters)]
|
||||
local y = preprocess(src, size, options)
|
||||
|
@ -122,10 +139,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
|||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg"):depth(8)
|
||||
if options.jpeg_sampling_factors == 444 then
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
else -- 420
|
||||
if torch.uniform() < options.jpeg_chroma_subsampling_rate then
|
||||
-- YUV 420
|
||||
x:samplingFactors({2.0, 1.0, 1.0})
|
||||
else
|
||||
-- YUV 444
|
||||
x:samplingFactors({1.0, 1.0, 1.0})
|
||||
end
|
||||
local blob, len = x:toBlob(quality[i])
|
||||
x:fromBlob(blob, len)
|
||||
|
@ -188,23 +207,10 @@ function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
|
|||
error("unknown noise level: " .. level)
|
||||
end
|
||||
elseif style == "photo" then
|
||||
if level == 1 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(30, 75)},
|
||||
size, offset, n,
|
||||
options)
|
||||
elseif level == 2 then
|
||||
if torch.uniform() > 0.6 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(30, 60)},
|
||||
size, offset, n, options)
|
||||
else
|
||||
local quality1 = torch.random(40, 60)
|
||||
local quality2 = quality1 - torch.random(5, 10)
|
||||
return pairwise_transform.jpeg_(src, {quality1, quality2},
|
||||
size, offset, n, options)
|
||||
end
|
||||
else
|
||||
error("unknown noise level: " .. level)
|
||||
end
|
||||
-- level adjusting by -nr_rate
|
||||
return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
|
||||
size, offset, n,
|
||||
options)
|
||||
else
|
||||
error("unknown style: " .. style)
|
||||
end
|
||||
|
@ -215,6 +221,8 @@ function pairwise_transform.test_jpeg(src)
|
|||
local options = {random_color_noise_rate = 0.5,
|
||||
random_half_rate = 0.5,
|
||||
random_overlay_rate = 0.5,
|
||||
random_unsharp_mask_rate = 0.5,
|
||||
jpeg_chroma_subsampling_rate = 0.5,
|
||||
nr_rate = 1.0,
|
||||
active_cropping_rate = 0.5,
|
||||
active_cropping_tries = 10,
|
||||
|
@ -237,6 +245,7 @@ function pairwise_transform.test_scale(src)
|
|||
local options = {random_color_noise_rate = 0.5,
|
||||
random_half_rate = 0.5,
|
||||
random_overlay_rate = 0.5,
|
||||
random_unsharp_mask_rate = 0.5,
|
||||
active_cropping_rate = 0.5,
|
||||
active_cropping_tries = 10,
|
||||
max_size = 256,
|
||||
|
|
|
@ -30,35 +30,53 @@ cmd:option("-color", 'rgb', '(y|rgb)')
|
|||
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
|
||||
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
|
||||
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
|
||||
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
|
||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||
cmd:option("-learning_rate", 0.001, 'learning rate for adam')
|
||||
cmd:option("-crop_size", 46, 'crop size')
|
||||
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
|
||||
cmd:option("-batch_size", 8, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'number of total epochs to run')
|
||||
cmd:option("-batch_size", 32, 'mini batch size')
|
||||
cmd:option("-patches", 16, 'number of patch samples')
|
||||
cmd:option("-inner_epoch", 4, 'number of inner epochs')
|
||||
cmd:option("-epoch", 30, 'number of epochs to run')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
|
||||
cmd:option("-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)')
|
||||
cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
|
||||
cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
|
||||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
||||
cmd:option("-save_history", 0, 'save all model (0|1)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
for k, v in pairs(opt) do
|
||||
settings[k] = v
|
||||
end
|
||||
if settings.method == "noise" then
|
||||
settings.model_file = string.format("%s/noise%d_model.t7",
|
||||
settings.model_dir, settings.noise_level)
|
||||
elseif settings.method == "scale" then
|
||||
settings.model_file = string.format("%s/scale%.1fx_model.t7",
|
||||
settings.model_dir, settings.scale)
|
||||
elseif settings.method == "noise_scale" then
|
||||
settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
|
||||
settings.model_dir, settings.noise_level, settings.scale)
|
||||
if settings.save_history == 1 then
|
||||
settings.save_history = true
|
||||
else
|
||||
error("unknown method: " .. settings.method)
|
||||
settings.save_history = false
|
||||
end
|
||||
if settings.save_history then
|
||||
if settings.method == "noise" then
|
||||
settings.model_file = string.format("%s/noise%d_model.%%d-%%d.t7",
|
||||
settings.model_dir, settings.noise_level)
|
||||
elseif settings.method == "scale" then
|
||||
settings.model_file = string.format("%s/scale%.1fx_model.%%d-%%d.t7",
|
||||
settings.model_dir, settings.scale)
|
||||
else
|
||||
error("unknown method: " .. settings.method)
|
||||
end
|
||||
else
|
||||
if settings.method == "noise" then
|
||||
settings.model_file = string.format("%s/noise%d_model.t7",
|
||||
settings.model_dir, settings.noise_level)
|
||||
elseif settings.method == "scale" then
|
||||
settings.model_file = string.format("%s/scale%.1fx_model.t7",
|
||||
settings.model_dir, settings.scale)
|
||||
else
|
||||
error("unknown method: " .. settings.method)
|
||||
end
|
||||
end
|
||||
if not (settings.color == "rgb" or settings.color == "y") then
|
||||
error("color must be y or rgb")
|
||||
|
|
BIN
models/photo/noise1_model.json
Normal file
BIN
models/photo/noise1_model.json
Normal file
Binary file not shown.
BIN
models/photo/noise1_model.t7
Normal file
BIN
models/photo/noise1_model.t7
Normal file
Binary file not shown.
BIN
models/photo/noise2_model.json
Normal file
BIN
models/photo/noise2_model.json
Normal file
Binary file not shown.
BIN
models/photo/noise2_model.t7
Normal file
BIN
models/photo/noise2_model.t7
Normal file
Binary file not shown.
BIN
models/photo/scale2.0x_model.json
Normal file
BIN
models/photo/scale2.0x_model.json
Normal file
Binary file not shown.
BIN
models/photo/scale2.0x_model.t7
Normal file
BIN
models/photo/scale2.0x_model.t7
Normal file
Binary file not shown.
|
@ -23,6 +23,7 @@ cmd:option("-noise_level", 1, 'model noise level')
|
|||
cmd:option("-jpeg_quality", 75, 'jpeg quality')
|
||||
cmd:option("-jpeg_times", 1, 'jpeg compression times')
|
||||
cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
|
||||
cmd:option("-range_bug", 0, 'Reproducing the dynamic range bug that is caused by MATLAB\'s rgb2ycbcr(1|0)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
torch.setdefaulttensortype('torch.FloatTensor')
|
||||
|
@ -41,25 +42,33 @@ local function rgb2y_matlab(x)
|
|||
return y:byte():float()
|
||||
end
|
||||
|
||||
local function MSE(x1, x2)
|
||||
--- Mean squared error between two images over all elements.
-- Both inputs are first passed through iproc.float2byte and converted back
-- to float, so the error is measured on the byte-quantized values.
-- @param x1 first image (presumably a float tensor in [0, 1] -- verify against iproc)
-- @param x2 second image, same size as x1
-- @return mean of the squared element-wise difference
local function RGBMSE(x1, x2)
   local lhs = iproc.float2byte(x1):float()
   local rhs = iproc.float2byte(x2):float()
   local diff = lhs - rhs
   return diff:pow(2):mean()
end
|
||||
local function YMSE(x1, x2)
|
||||
local x1_2 = rgb2y_matlab(x1)
|
||||
local x2_2 = rgb2y_matlab(x2)
|
||||
return (x1_2 - x2_2):pow(2):mean()
|
||||
if opt.range_bug == 1 then
|
||||
local x1_2 = rgb2y_matlab(x1)
|
||||
local x2_2 = rgb2y_matlab(x2)
|
||||
return (x1_2 - x2_2):pow(2):mean()
|
||||
else
|
||||
local x1_2 = image.rgb2y(x1):mul(255.0)
|
||||
local x2_2 = image.rgb2y(x2):mul(255.0)
|
||||
return (x1_2 - x2_2):pow(2):mean()
|
||||
end
|
||||
end
|
||||
local function PSNR(x1, x2)
|
||||
local mse = MSE(x1, x2)
|
||||
--- Dispatch to the MSE implementation selected by `color`.
-- @param x1    first image
-- @param x2    second image
-- @param color "y" selects YMSE; any other value selects RGBMSE
-- @return the value returned by the selected function
local function MSE(x1, x2, color)
   if color ~= "y" then
      return RGBMSE(x1, x2)
   end
   return YMSE(x1, x2)
end
|
||||
--- Peak signal-to-noise ratio in dB for a peak value of 255.
-- @param x1    first image
-- @param x2    second image
-- @param color forwarded unchanged to MSE to pick the error metric
-- @return 10 * log10(255^2 / MSE(x1, x2, color))
local function PSNR(x1, x2, color)
   local peak_squared = 255.0 * 255.0
   local err = MSE(x1, x2, color)
   return 10 * math.log10(peak_squared / err)
end
|
||||
local function YPSNR(x1, x2)
|
||||
local mse = YMSE(x1, x2)
|
||||
return 10 * math.log10((255.0 * 255.0) / mse)
|
||||
end
|
||||
|
||||
local function transform_jpeg(x, opt)
|
||||
for i = 1, opt.jpeg_times do
|
||||
jpeg = gm.Image(x, "RGB", "DHW")
|
||||
|
@ -69,7 +78,7 @@ local function transform_jpeg(x, opt)
|
|||
jpeg:fromBlob(blob, len)
|
||||
x = jpeg:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
return x
|
||||
return iproc.byte2float(x)
|
||||
end
|
||||
local function baseline_scale(x, filter)
|
||||
return iproc.scale(x,
|
||||
|
@ -110,62 +119,47 @@ local function benchmark(opt, x, input_func, model1, model2)
|
|||
end
|
||||
baseline_output = baseline_scale(input, opt.filter)
|
||||
end
|
||||
if opt.color == "y" then
|
||||
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
|
||||
model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
|
||||
if model2 then
|
||||
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
|
||||
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
|
||||
end
|
||||
if baseline_output then
|
||||
baseline_mse = baseline_mse + YMSE(ground_truth, baseline_output)
|
||||
baseline_psnr = baseline_psnr + YPSNR(ground_truth, baseline_output)
|
||||
end
|
||||
elseif opt.color == "rgb" then
|
||||
model1_mse = model1_mse + MSE(ground_truth, model1_output)
|
||||
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
|
||||
if model2 then
|
||||
model2_mse = model2_mse + MSE(ground_truth, model2_output)
|
||||
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
|
||||
end
|
||||
if baseline_output then
|
||||
baseline_mse = baseline_mse + MSE(ground_truth, baseline_output)
|
||||
baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output)
|
||||
end
|
||||
else
|
||||
error("Unknown color: " .. opt.color)
|
||||
model1_mse = model1_mse + MSE(ground_truth, model1_output, opt.color)
|
||||
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output, opt.color)
|
||||
if model2 then
|
||||
model2_mse = model2_mse + MSE(ground_truth, model2_output, opt.color)
|
||||
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output, opt.color)
|
||||
end
|
||||
if baseline_output then
|
||||
baseline_mse = baseline_mse + MSE(ground_truth, baseline_output, opt.color)
|
||||
baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output, opt.color)
|
||||
end
|
||||
if model2 then
|
||||
if baseline_output then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; baseline_mse=%f, model1_mse=%f, model2_mse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
i, #x,
|
||||
baseline_mse / i,
|
||||
model1_mse / i, model2_mse / i,
|
||||
math.sqrt(baseline_mse / i),
|
||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||
baseline_psnr / i,
|
||||
model1_psnr / i, model2_psnr / i
|
||||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
string.format("%d/%d; model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
|
||||
i, #x,
|
||||
model1_mse / i, model2_mse / i,
|
||||
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
|
||||
model1_psnr / i, model2_psnr / i
|
||||
))
|
||||
end
|
||||
else
|
||||
if baseline_output then
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; baseline_mse=%f, model1_mse=%f, baseline_psnr=%f, model1_psnr=%f \r",
|
||||
string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r",
|
||||
i, #x,
|
||||
baseline_mse / i, model1_mse / i,
|
||||
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
|
||||
baseline_psnr / i, model1_psnr / i
|
||||
))
|
||||
else
|
||||
io.stdout:write(
|
||||
string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
|
||||
string.format("%d/%d; model1_rmse=%f, model1_psnr=%f \r",
|
||||
i, #x,
|
||||
model1_mse / i, model1_psnr / i
|
||||
math.sqrt(model1_mse / i), model1_psnr / i
|
||||
))
|
||||
end
|
||||
end
|
||||
|
|
120
train.lua
120
train.lua
|
@ -35,14 +35,14 @@ local function split_data(x, test_size)
|
|||
end
|
||||
return train_x, valid_x
|
||||
end
|
||||
local function make_validation_set(x, transformer, n, batch_size)
|
||||
local function make_validation_set(x, transformer, n, patches)
|
||||
n = n or 4
|
||||
local data = {}
|
||||
for i = 1, #x do
|
||||
for k = 1, math.max(n / batch_size, 1) do
|
||||
local xy = transformer(x[i], true, batch_size)
|
||||
local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
|
||||
local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
|
||||
for k = 1, math.max(n / patches, 1) do
|
||||
local xy = transformer(x[i], true, patches)
|
||||
local tx = torch.Tensor(patches, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
|
||||
local ty = torch.Tensor(patches, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
|
||||
for j = 1, #xy do
|
||||
tx[j]:copy(xy[j][1])
|
||||
ty[j]:copy(xy[j][2])
|
||||
|
@ -83,7 +83,8 @@ local function create_criterion(model)
|
|||
end
|
||||
local function transformer(x, is_validation, n, offset)
|
||||
x = compression.decompress(x)
|
||||
n = n or settings.batch_size;
|
||||
n = n or settings.patches
|
||||
|
||||
if is_validation == nil then is_validation = false end
|
||||
local random_color_noise_rate = nil
|
||||
local random_overlay_rate = nil
|
||||
|
@ -110,6 +111,7 @@ local function transformer(x, is_validation, n, offset)
|
|||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
|
@ -125,8 +127,9 @@ local function transformer(x, is_validation, n, offset)
|
|||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
|
||||
max_size = settings.max_size,
|
||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
nr_rate = settings.nr_rate,
|
||||
|
@ -135,7 +138,24 @@ local function transformer(x, is_validation, n, offset)
|
|||
end
|
||||
end
|
||||
|
||||
--- Refill the patch buffers `x` and `y` from the training images.
-- For every entry of `train_x`, `transformer` is asked for `settings.patches`
-- (input, target) pairs; pair i of image t is copied into row
-- (t - 1) * settings.patches + i of `x` and `y`.
-- @param x           destination for the inputs, indexed by row
-- @param y           destination for the targets, indexed by row
-- @param train_x     table of training images
-- @param transformer function(image, is_validation, n) returning a list of
--                    {input, target} pairs
-- @param input_size  unused here; kept for interface compatibility
-- @param target_size unused here; kept for interface compatibility
local function resampling(x, y, train_x, transformer, input_size, target_size)
   print("## resampling")
   local total = #train_x
   for t = 1, total do
      xlua.progress(t, #train_x)
      local samples = transformer(train_x[t], false, settings.patches)
      for i = 1, #samples do
         local row = (t - 1) * settings.patches + i
         local sample = samples[i]
         x[row]:copy(sample[1])
         y[row]:copy(sample[2])
      end
      -- periodically release the temporaries produced by the transformer
      if t % 50 == 0 then
         collectgarbage()
      end
   end
end
|
||||
|
||||
local function train()
|
||||
local LR_MIN = 1.0e-5
|
||||
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local pairwise_func = function(x, is_validation, n)
|
||||
|
@ -143,12 +163,12 @@ local function train()
|
|||
end
|
||||
local criterion = create_criterion(model)
|
||||
local x = torch.load(settings.images)
|
||||
local lrd_count = 0
|
||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
||||
local adam_config = {
|
||||
learningRate = settings.learning_rate,
|
||||
xBatchSize = settings.batch_size,
|
||||
}
|
||||
local lrd_count = 0
|
||||
local ch = nil
|
||||
if settings.color == "y" then
|
||||
ch = 1
|
||||
|
@ -159,48 +179,70 @@ local function train()
|
|||
print("# make validation-set")
|
||||
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
||||
settings.validation_crops,
|
||||
settings.batch_size)
|
||||
settings.patches)
|
||||
valid_x = nil
|
||||
|
||||
collectgarbage()
|
||||
model:cuda()
|
||||
print("load .. " .. #train_x)
|
||||
|
||||
local x = torch.Tensor(settings.patches * #train_x,
|
||||
ch, settings.crop_size, settings.crop_size)
|
||||
local y = torch.Tensor(settings.patches * #train_x,
|
||||
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
|
||||
|
||||
for epoch = 1, settings.epoch do
|
||||
model:training()
|
||||
print("# " .. epoch)
|
||||
print(minibatch_adam(model, criterion, train_x, adam_config,
|
||||
pairwise_func,
|
||||
{ch, settings.crop_size, settings.crop_size},
|
||||
{ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
|
||||
))
|
||||
model:evaluate()
|
||||
print("# validation")
|
||||
local score = validate(model, criterion, valid_xy)
|
||||
if score < best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
lrd_count = 0
|
||||
best_score = score
|
||||
print("* update best model")
|
||||
torch.save(settings.model_file, model)
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.png"):format(settings.noise_level))
|
||||
save_test_jpeg(model, test_image, log)
|
||||
elseif settings.method == "scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("scale%.1f_best.png"):format(settings.scale))
|
||||
save_test_scale(model, test_image, log)
|
||||
end
|
||||
else
|
||||
lrd_count = lrd_count + 1
|
||||
if lrd_count > 5 then
|
||||
resampling(x, y, train_x, pairwise_func)
|
||||
for i = 1, settings.inner_epoch do
|
||||
print(minibatch_adam(model, criterion, x, y, adam_config))
|
||||
model:evaluate()
|
||||
print("# validation")
|
||||
local score = validate(model, criterion, valid_xy)
|
||||
if score < best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
lrd_count = 0
|
||||
adam_config.learningRate = adam_config.learningRate * 0.9
|
||||
print("* learning rate decay: " .. adam_config.learningRate)
|
||||
best_score = score
|
||||
print("* update best model")
|
||||
if settings.save_history then
|
||||
local model_clone = model:clone()
|
||||
w2nn.cleanup_model(model_clone)
|
||||
torch.save(string.format(settings.model_file, epoch, i), model_clone)
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.%d-%d.png"):format(settings.noise_level,
|
||||
epoch, i))
|
||||
save_test_jpeg(model, test_image, log)
|
||||
elseif settings.method == "scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("scale%.1f_best.%d-%d.png"):format(settings.scale,
|
||||
epoch, i))
|
||||
save_test_scale(model, test_image, log)
|
||||
end
|
||||
else
|
||||
torch.save(settings.model_file, model)
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.png"):format(settings.noise_level))
|
||||
save_test_jpeg(model, test_image, log)
|
||||
elseif settings.method == "scale" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("scale%.1f_best.png"):format(settings.scale))
|
||||
save_test_scale(model, test_image, log)
|
||||
end
|
||||
end
|
||||
else
|
||||
lrd_count = lrd_count + 1
|
||||
if lrd_count > 2 and adam_config.learningRate > LR_MIN then
|
||||
adam_config.learningRate = adam_config.learningRate * 0.8
|
||||
print("* learning rate decay: " .. adam_config.learningRate)
|
||||
lrd_count = 0
|
||||
end
|
||||
end
|
||||
print("current: " .. score .. ", best: " .. best_score)
|
||||
collectgarbage()
|
||||
end
|
||||
print("current: " .. score .. ", best: " .. best_score)
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
if settings.gpu > 0 then
|
||||
|
|
14
web.lua
14
web.lua
|
@ -12,8 +12,9 @@ local iproc = require 'iproc'
|
|||
local reconstruct = require 'reconstruct'
|
||||
local image_loader = require 'image_loader'
|
||||
local alpha_util = require 'alpha_util'
|
||||
local gm = require 'graphicsmagick'
|
||||
|
||||
-- Notes: turbo and xlua has different implementation of string:split().
|
||||
-- Note: turbo and xlua have different implementations of string:split().
|
||||
-- Therefore, the two definitions of string:split() conflict.
|
||||
-- In this script, use turbo's string:split().
|
||||
local turbo = require 'turbo'
|
||||
|
@ -36,13 +37,13 @@ if cudnn then
|
|||
cudnn.benchmark = false
|
||||
end
|
||||
local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
|
||||
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
|
||||
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
|
||||
local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
--local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
--local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
--local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could use this flag.
|
||||
local CACHE_DIR = path.join(ROOT, "cache")
|
||||
local MAX_NOISE_IMAGE = 2560 * 2560
|
||||
|
@ -143,7 +144,7 @@ local function convert(x, alpha, options)
|
|||
x = reconstruct.image(art_noise2_model, x)
|
||||
cleanup_model(art_noise2_model)
|
||||
end
|
||||
else --[[photo
|
||||
else -- photo
|
||||
if options.border then
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
|
||||
end
|
||||
|
@ -163,7 +164,6 @@ local function convert(x, alpha, options)
|
|||
x = reconstruct.image(photo_noise2_model, x)
|
||||
cleanup_model(photo_noise2_model)
|
||||
end
|
||||
--]]
|
||||
end
|
||||
image_loader.save_png(cache_file, x)
|
||||
|
||||
|
|
Loading…
Reference in a new issue