Merge pull request #71 from nagadomi/photo

Add support for photo scaling/jpeg denoising.
nagadomi 2015-12-04 22:57:41 +09:00
commit e4b239ec14
18 changed files with 290 additions and 186 deletions

README.md
View file

@@ -1,6 +1,7 @@
# waifu2x
Image Super-Resolution for anime-style-art using Deep Convolutional Neural Networks.
Image Super-Resolution for Anime-style art using Deep Convolutional Neural Networks.
It also supports photos.
Demo-Application can be found at http://waifu2x.udp.jp/ .
@@ -123,6 +124,15 @@ th waifu2x.lua -m noise_scale -noise_level 2 -i input_image.png -o output_image.
See also `th waifu2x.lua -h`.
### Using the photo model
If you want to use the photo model, add `-model_dir models/photo` to the command-line options.
For example,
```
th waifu2x.lua -model_dir models/photo -m scale -i input_image.png -o output_image.png
```
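
The photo model directory also contains denoising weights (loaded by `web.lua` later in this commit), so combining scaling and denoising should work the same way. A hedged example, reusing the flags documented above:
```
th waifu2x.lua -model_dir models/photo -m noise_scale -noise_level 1 -i input_image.png -o output_image.png
```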
### Video Encoding
\* `avconv` is an alias of `ffmpeg` on Ubuntu 14.04.

index.html
View file

@@ -4,7 +4,8 @@
<meta charset="UTF-8">
<title>waifu2x</title>
<link href="style.css" rel="stylesheet" type="text/css">
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
<script type="text/javascript" src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
<script type="text/javascript" src="ui.js"></script>
</head>
<body>
@@ -17,7 +18,7 @@
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
</div>
<div class="about">
<div>Single-Image Super-Resolution for anime/fan-art using Deep Convolutional Neural Networks. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
<div>Single-Image Super-Resolution for Anime-Style Art using Deep Convolutional Neural Networks. It also supports photos. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
</div>
<form action="/api" method="POST" enctype="multipart/form-data" target="_blank">
<fieldset>
@@ -32,6 +33,11 @@
Limits: Size: 2MB, Noise Reduction: 2560x2560px, Upscaling: 1280x1280px
</div>
</fieldset>
<fieldset>
<legend>Style</legend>
<label><input type="radio" name="style" value="art" checked>Art</label>
<label><input type="radio" name="style" value="photo">Photo</label>
</fieldset>
<fieldset class="noise-field">
<legend>Noise Reduction (expects JPEG artifacts)</legend>
<label><input type="radio" name="noise" value="0"> None</label>

index.ja.html
View file

@@ -5,6 +5,7 @@
<link href="style.css" rel="stylesheet" type="text/css">
<title>waifu2x</title>
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
<script type="text/javascript" src="ui.js"></script>
</head>
<body>
@@ -17,7 +18,7 @@
<a href="index.html">en</a>/<a href="index.ja.html">ja</a>/<a href="index.ru.html">ru</a>
</div>
<div class="about">
<div>深層畳み込みニューラルネットワークによる二次元画像のための超解像システム. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
<div>深層畳み込みニューラルネットワークによる二次元画像のための超解像システム. 写真にも対応. <a href="https://raw.githubusercontent.com/nagadomi/waifu2x/master/images/slide.png" target="_blank">about</a>.</div>
</div>
<form action="/api" method="POST" enctype="multipart/form-data" target="_blank">
<fieldset>
@@ -32,6 +33,11 @@
制限: サイズ: 2MB, ノイズ除去: 2560x2560px, 拡大: 1280x1280px
</div>
</fieldset>
<fieldset>
<legend>スタイル</legend>
<label><input type="radio" name="style" value="art" checked>イラスト</label>
<label><input type="radio" name="style" value="photo">写真</label>
</fieldset>
<fieldset class="noise-field">
<legend>ノイズ除去 (JPEGノイズを想定)</legend>
<label><input type="radio" name="noise" value="0"> なし</label>

index.ru.html
View file

@@ -6,6 +6,7 @@
<title>waifu2x</title>
<link href="style.css" rel="stylesheet" type="text/css">
<script type="text/javascript" src="http://ajax.googleapis.com/ajax/libs/jquery/2.1.3/jquery.min.js"></script>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/jquery-cookie/1.4.1/jquery.cookie.js"></script>
<script type="text/javascript" src="ui.js"></script>
</head>
<body>
@@ -33,6 +34,11 @@
Макс. размер файла — 2MB, устранение шума — макс. 2560x2560px, апскейл — 1280x1280px
</div>
</fieldset>
<fieldset>
<legend>Стиль</legend>
<label><input type="radio" name="style" value="art" checked>Произведение искусства</label>
<label><input type="radio" name="style" value="photo">фото</label>
</fieldset>
<fieldset class="noise-field">
<legend>Устранение шума (артефактов JPEG)</legend>
<label><input type="radio" name="noise" value="0"> Нет</label>

ui.js
View file

@@ -1,4 +1,5 @@
$(function (){
var expires = 365;
function clear_file() {
var new_file = $("#file").clone();
new_file.change(clear_url);
@@ -19,6 +20,7 @@ $(function (){
} else {
$("h1").html("w<s>/a/</s>ifu2x");
}
$.cookie("style", checked.val(), {expires: expires});
}
function on_change_noise_level(e)
{
@@ -30,6 +32,7 @@ $(function (){
if (checked.val() != 0) {
checked.parents("label").css("font-weight", "bold");
}
$.cookie("noise", checked.val(), {expires: expires});
}
function on_change_scale_factor(e)
{
@@ -41,40 +44,29 @@ $(function (){
if (checked.val() != 0) {
checked.parents("label").css("font-weight", "bold");
}
$.cookie("scale", checked.val(), {expires: expires});
}
function on_change_white_noise(e)
function restore_from_cookie()
{
$("input[name=white_noise]").parents("label").each(
function (i, elm) {
$(elm).css("font-weight", "normal");
});
var checked = $("input[name=white_noise]:checked");
if (checked.val() != 0) {
checked.parents("label").css("font-weight", "bold");
if ($.cookie("style")) {
$("input[name=style]").filter("[value=" + $.cookie("style") + "]").prop("checked", true)
}
}
function on_click_experimental_button(e)
{
if ($(this).hasClass("close")) {
$(".experimental .container").show();
$(this).removeClass("close");
} else {
$(".experimental .container").hide();
$(this).addClass("close");
if ($.cookie("noise")) {
$("input[name=noise]").filter("[value=" + $.cookie("noise") + "]").prop("checked", true)
}
if ($.cookie("scale")) {
$("input[name=scale]").filter("[value=" + $.cookie("scale") + "]").prop("checked", true)
}
e.preventDefault();
e.stopPropagation();
}
$("#url").change(clear_file);
$("#file").change(clear_url);
//$("input[name=style]").change(on_change_style);
$("input[name=style]").change(on_change_style);
$("input[name=noise]").change(on_change_noise_level);
$("input[name=scale]").change(on_change_scale_factor);
//$("input[name=white_noise]").change(on_change_white_noise);
//$(".experimental .button").click(on_click_experimental_button)
//on_change_style();
restore_from_cookie();
on_change_style();
on_change_scale_factor();
on_change_noise_level();
})

data_augmentation.lua
View file

@@ -1,5 +1,6 @@
require 'image'
local iproc = require 'iproc'
local gm = require 'graphicsmagick'
local data_augmentation = {}
@@ -50,6 +51,25 @@ function data_augmentation.overlay(src, p)
return src
end
end
function data_augmentation.unsharp_mask(src, p)
if torch.uniform() < p then
local radius = 0 -- auto
local sigma = torch.uniform(0.5, 1.5)
local amount = torch.uniform(0.1, 0.9)
local threshold = torch.uniform(0.0, 0.05)
local unsharp = gm.Image(src, "RGB", "DHW"):
unsharpMask(radius, sigma, amount, threshold):
toTensor("float", "RGB", "DHW")
if src:type() == "torch.ByteTensor" then
return iproc.float2byte(unsharp)
else
return unsharp
end
else
return src
end
end
function data_augmentation.shift_1px(src)
-- reduces the even/odd issue in nearest-neighbor scalers.
local direction = torch.random(1, 4)
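
The new `unsharp_mask` augmentation sharpens a sample with some probability. For intuition, here is a minimal sketch of the underlying operation (my own illustration, not the repo's code: single-channel float image, box blur instead of Gaussian, threshold term omitted):
```lua
require 'torch'

-- output = src + amount * (src - blurred), for an HxW tensor in [0, 1]
local function unsharp_sketch(src, amount)
   local k = torch.ones(5, 5):div(25)        -- normalized box-blur kernel
   local blurred = torch.conv2(src, k, 'F')  -- 'F' = full convolution
   -- crop the padded result back to HxW (a 5x5 kernel adds a 2px border)
   blurred = blurred:narrow(1, 3, src:size(1)):narrow(2, 3, src:size(2))
   return (src + (src - blurred) * amount):clamp(0, 1)
end
```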

minibatch_adam.lua
View file

@@ -3,30 +3,32 @@ require 'cutorch'
require 'xlua'
local function minibatch_adam(model, criterion,
train_x,
config, transformer,
input_size, target_size)
train_x, train_y,
config)
local parameters, gradParameters = model:getParameters()
config = config or {}
local sum_loss = 0
local count_loss = 0
local batch_size = config.xBatchSize or 32
local shuffle = torch.randperm(#train_x)
local shuffle = torch.randperm(train_x:size(1))
local c = 1
local inputs = torch.Tensor(batch_size,
input_size[1], input_size[2], input_size[3]):cuda()
local targets = torch.Tensor(batch_size,
target_size[1] * target_size[2] * target_size[3]):cuda()
local inputs_tmp = torch.Tensor(batch_size,
input_size[1], input_size[2], input_size[3])
train_x:size(2), train_x:size(3), train_x:size(4)):zero()
local targets_tmp = torch.Tensor(batch_size,
target_size[1] * target_size[2] * target_size[3])
for t = 1, #train_x do
xlua.progress(t, #train_x)
local xy = transformer(train_x[shuffle[t]], false, batch_size)
for i = 1, #xy do
inputs_tmp[i]:copy(xy[i][1])
targets_tmp[i]:copy(xy[i][2])
train_y:size(2)):zero()
local inputs = inputs_tmp:clone():cuda()
local targets = targets_tmp:clone():cuda()
print("## update")
for t = 1, train_x:size(1), batch_size do
if t + batch_size - 1 > train_x:size(1) then
break
end
xlua.progress(t, train_x:size(1))
for i = 1, batch_size do
inputs_tmp[i]:copy(train_x[shuffle[t + i - 1]])
targets_tmp[i]:copy(train_y[shuffle[t + i - 1]])
end
inputs:copy(inputs_tmp)
targets:copy(targets_tmp)
@@ -43,13 +45,12 @@ local function minibatch_adam(model, criterion,
return f, gradParameters
end
optim.adam(feval, parameters, config)
c = c + 1
if c % 20 == 0 then
if c % 50 == 0 then
collectgarbage()
end
end
xlua.progress(#train_x, #train_x)
xlua.progress(train_x:size(1), train_x:size(1))
return { loss = sum_loss / count_loss}
end
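
After this change, `minibatch_adam` consumes precomputed sample tensors instead of a transformer callback. A sketch of the new call site, matching how `train.lua` invokes it later in this commit (`x` and `y` are the tensors filled by `resampling()`):
```lua
local result = minibatch_adam(model, criterion, x, y,
                              {learningRate = 0.001, xBatchSize = 32})
print(result.loss) -- mean minibatch loss over the epoch
```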

pairwise_transform.lua
View file

@@ -7,7 +7,7 @@ local pairwise_transform = {}
local function random_half(src, p)
if torch.uniform() < p then
local filter = ({"Box","Box","Blackman","Sinc","Lanczos"})[torch.random(1, 5)]
local filter = ({"Box","Box","Blackman","Sinc","Lanczos", "Catrom"})[torch.random(1, 6)]
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
else
return src
@@ -38,6 +38,7 @@ local function preprocess(src, crop_size, options)
dest = data_augmentation.flip(dest)
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
dest = data_augmentation.unsharp_mask(dest, options.random_unsharp_mask_rate)
dest = data_augmentation.shift_1px(dest)
return dest
@@ -45,6 +46,10 @@ end
local function active_cropping(x, y, size, p, tries)
assert("x:size == y:size", x:size(2) == y:size(2) and x:size(3) == y:size(3))
local r = torch.uniform()
local t = "float"
if x:type() == "torch.ByteTensor" then
t = "byte"
end
if p < r then
local xi = torch.random(0, y:size(3) - (size + 1))
local yi = torch.random(0, y:size(2) - (size + 1))
@@ -52,6 +57,10 @@ local function active_cropping(x, y, size, p, tries)
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
return xc, yc
else
local lowres = gm.Image(x, "RGB", "DHW"):
size(x:size(3) * 0.5, x:size(2) * 0.5, "Box"):
size(x:size(3), x:size(2), "Box"):
toTensor(t, "RGB", "DHW")
local best_se = 0.0
local best_xc, best_yc
local m = torch.FloatTensor(x:size(1), size, size)
@@ -59,13 +68,13 @@ local function active_cropping(x, y, size, p, tries)
local xi = torch.random(0, y:size(3) - (size + 1))
local yi = torch.random(0, y:size(2) - (size + 1))
local xc = iproc.crop(x, xi, yi, xi + size, yi + size)
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
local lc = iproc.crop(lowres, xi, yi, xi + size, yi + size)
local xcf = iproc.byte2float(xc)
local ycf = iproc.byte2float(yc)
local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
local lcf = iproc.byte2float(lc)
local se = m:copy(xcf):add(-1.0, lcf):pow(2):sum()
if se >= best_se then
best_xc = xcf
best_yc = ycf
best_yc = iproc.byte2float(iproc.crop(y, xi, yi, xi + size, yi + size))
best_se = se
end
end
@@ -73,15 +82,23 @@ local function active_cropping(x, y, size, p, tries)
end
end
function pairwise_transform.scale(src, scale, size, offset, n, options)
local filters = {
"Box","Box", -- 0.012756949974688
"Blackman", -- 0.013191924552285
--"Cartom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
--"Hermite", -- 0.013850225205266
"Sinc", -- 0.014095824314306
"Lanczos", -- 0.014244299255442
}
local filters
if options.style == "photo" then
filters = {
"Box", "Lanczos", "Catrom"
}
else
filters = {
"Box","Box", -- 0.012756949974688
"Blackman", -- 0.013191924552285
--"Catrom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
--"Hermite", -- 0.013850225205266
"Sinc", -- 0.014095824314306
"Lanczos", -- 0.014244299255442
}
end
local unstable_region_offset = 8
local downscale_filter = filters[torch.random(1, #filters)]
local y = preprocess(src, size, options)
@@ -122,10 +139,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg"):depth(8)
if options.jpeg_sampling_factors == 444 then
x:samplingFactors({1.0, 1.0, 1.0})
else -- 420
if torch.uniform() < options.jpeg_chroma_subsampling_rate then
-- YUV 420
x:samplingFactors({2.0, 1.0, 1.0})
else
-- YUV 444
x:samplingFactors({1.0, 1.0, 1.0})
end
local blob, len = x:toBlob(quality[i])
x:fromBlob(blob, len)
@@ -188,23 +207,10 @@ function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
error("unknown noise level: " .. level)
end
elseif style == "photo" then
if level == 1 then
return pairwise_transform.jpeg_(src, {torch.random(30, 75)},
size, offset, n,
options)
elseif level == 2 then
if torch.uniform() > 0.6 then
return pairwise_transform.jpeg_(src, {torch.random(30, 60)},
size, offset, n, options)
else
local quality1 = torch.random(40, 60)
local quality2 = quality1 - torch.random(5, 10)
return pairwise_transform.jpeg_(src, {quality1, quality2},
size, offset, n, options)
end
else
error("unknown noise level: " .. level)
end
-- the noise level is adjusted by -nr_rate
return pairwise_transform.jpeg_(src, {torch.random(30, 70)},
size, offset, n,
options)
else
error("unknown style: " .. style)
end
@@ -215,6 +221,8 @@ function pairwise_transform.test_jpeg(src)
local options = {random_color_noise_rate = 0.5,
random_half_rate = 0.5,
random_overlay_rate = 0.5,
random_unsharp_mask_rate = 0.5,
jpeg_chroma_subsampling_rate = 0.5,
nr_rate = 1.0,
active_cropping_rate = 0.5,
active_cropping_tries = 10,
@@ -237,6 +245,7 @@ function pairwise_transform.test_scale(src)
local options = {random_color_noise_rate = 0.5,
random_half_rate = 0.5,
random_overlay_rate = 0.5,
random_unsharp_mask_rate = 0.5,
active_cropping_rate = 0.5,
active_cropping_tries = 10,
max_size = 256,
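
The `jpeg_chroma_subsampling_rate` switch above picks YUV 4:2:0 with the configured probability and YUV 4:4:4 otherwise. Condensed into a standalone sketch (one compression round; every call is taken verbatim from the hunks above):
```lua
require 'torch'
local gm = require 'graphicsmagick'

local function degrade_jpeg(src, quality, chroma_420_rate)
   local x = gm.Image(src, "RGB", "DHW")
   x:format("jpeg"):depth(8)
   if torch.uniform() < chroma_420_rate then
      x:samplingFactors({2.0, 1.0, 1.0}) -- YUV 4:2:0
   else
      x:samplingFactors({1.0, 1.0, 1.0}) -- YUV 4:4:4
   end
   local blob, len = x:toBlob(quality)
   x:fromBlob(blob, len)
   return x:toTensor("byte", "RGB", "DHW")
end
```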

settings.lua
View file

@@ -30,35 +30,53 @@ cmd:option("-color", 'rgb', '(y|rgb)')
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
cmd:option("-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)')
cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-learning_rate", 0.001, 'learning rate for adam')
cmd:option("-crop_size", 46, 'crop size')
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
cmd:option("-batch_size", 8, 'mini batch size')
cmd:option("-epoch", 200, 'number of total epochs to run')
cmd:option("-batch_size", 32, 'mini batch size')
cmd:option("-patches", 16, 'number of patch samples')
cmd:option("-inner_epoch", 4, 'number of inner epochs')
cmd:option("-epoch", 30, 'number of epochs to run')
cmd:option("-thread", -1, 'number of CPU threads')
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
cmd:option("-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)')
cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
cmd:option("-save_history", 0, 'save all model (0|1)')
local opt = cmd:parse(arg)
for k, v in pairs(opt) do
settings[k] = v
end
if settings.method == "noise" then
settings.model_file = string.format("%s/noise%d_model.t7",
settings.model_dir, settings.noise_level)
elseif settings.method == "scale" then
settings.model_file = string.format("%s/scale%.1fx_model.t7",
settings.model_dir, settings.scale)
elseif settings.method == "noise_scale" then
settings.model_file = string.format("%s/noise%d_scale%.1fx_model.t7",
settings.model_dir, settings.noise_level, settings.scale)
if settings.save_history == 1 then
settings.save_history = true
else
error("unknown method: " .. settings.method)
settings.save_history = false
end
if settings.save_history then
if settings.method == "noise" then
settings.model_file = string.format("%s/noise%d_model.%%d-%%d.t7",
settings.model_dir, settings.noise_level)
elseif settings.method == "scale" then
settings.model_file = string.format("%s/scale%.1fx_model.%%d-%%d.t7",
settings.model_dir, settings.scale)
else
error("unknown method: " .. settings.method)
end
else
if settings.method == "noise" then
settings.model_file = string.format("%s/noise%d_model.t7",
settings.model_dir, settings.noise_level)
elseif settings.method == "scale" then
settings.model_file = string.format("%s/scale%.1fx_model.t7",
settings.model_dir, settings.scale)
else
error("unknown method: " .. settings.method)
end
end
if not (settings.color == "rgb" or settings.color == "y") then
error("color must be y or rgb")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

benchmark.lua
View file

@@ -23,6 +23,7 @@ cmd:option("-noise_level", 1, 'model noise level')
cmd:option("-jpeg_quality", 75, 'jpeg quality')
cmd:option("-jpeg_times", 1, 'jpeg compression times')
cmd:option("-jpeg_quality_down", 5, 'value of jpeg quality to decrease each times')
cmd:option("-range_bug", 0, 'Reproducing the dynamic range bug that is caused by MATLAB\'s rgb2ycbcr(1|0)')
local opt = cmd:parse(arg)
torch.setdefaulttensortype('torch.FloatTensor')
@@ -41,25 +42,33 @@ local function rgb2y_matlab(x)
return y:byte():float()
end
local function MSE(x1, x2)
local function RGBMSE(x1, x2)
x1 = iproc.float2byte(x1):float()
x2 = iproc.float2byte(x2):float()
return (x1 - x2):pow(2):mean()
end
local function YMSE(x1, x2)
local x1_2 = rgb2y_matlab(x1)
local x2_2 = rgb2y_matlab(x2)
return (x1_2 - x2_2):pow(2):mean()
if opt.range_bug == 1 then
local x1_2 = rgb2y_matlab(x1)
local x2_2 = rgb2y_matlab(x2)
return (x1_2 - x2_2):pow(2):mean()
else
local x1_2 = image.rgb2y(x1):mul(255.0)
local x2_2 = image.rgb2y(x2):mul(255.0)
return (x1_2 - x2_2):pow(2):mean()
end
end
local function PSNR(x1, x2)
local mse = MSE(x1, x2)
local function MSE(x1, x2, color)
if color == "y" then
return YMSE(x1, x2)
else
return RGBMSE(x1, x2)
end
end
local function PSNR(x1, x2, color)
local mse = MSE(x1, x2, color)
return 10 * math.log10((255.0 * 255.0) / mse)
end
local function YPSNR(x1, x2)
local mse = YMSE(x1, x2)
return 10 * math.log10((255.0 * 255.0) / mse)
end
local function transform_jpeg(x, opt)
for i = 1, opt.jpeg_times do
jpeg = gm.Image(x, "RGB", "DHW")
@@ -69,7 +78,7 @@ local function transform_jpeg(x, opt)
jpeg:fromBlob(blob, len)
x = jpeg:toTensor("byte", "RGB", "DHW")
end
return x
return iproc.byte2float(x)
end
local function baseline_scale(x, filter)
return iproc.scale(x,
@@ -110,62 +119,47 @@ local function benchmark(opt, x, input_func, model1, model2)
end
baseline_output = baseline_scale(input, opt.filter)
end
if opt.color == "y" then
model1_mse = model1_mse + YMSE(ground_truth, model1_output)
model1_psnr = model1_psnr + YPSNR(ground_truth, model1_output)
if model2 then
model2_mse = model2_mse + YMSE(ground_truth, model2_output)
model2_psnr = model2_psnr + YPSNR(ground_truth, model2_output)
end
if baseline_output then
baseline_mse = baseline_mse + YMSE(ground_truth, baseline_output)
baseline_psnr = baseline_psnr + YPSNR(ground_truth, baseline_output)
end
elseif opt.color == "rgb" then
model1_mse = model1_mse + MSE(ground_truth, model1_output)
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output)
if model2 then
model2_mse = model2_mse + MSE(ground_truth, model2_output)
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output)
end
if baseline_output then
baseline_mse = baseline_mse + MSE(ground_truth, baseline_output)
baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output)
end
else
error("Unknown color: " .. opt.color)
model1_mse = model1_mse + MSE(ground_truth, model1_output, opt.color)
model1_psnr = model1_psnr + PSNR(ground_truth, model1_output, opt.color)
if model2 then
model2_mse = model2_mse + MSE(ground_truth, model2_output, opt.color)
model2_psnr = model2_psnr + PSNR(ground_truth, model2_output, opt.color)
end
if baseline_output then
baseline_mse = baseline_mse + MSE(ground_truth, baseline_output, opt.color)
baseline_psnr = baseline_psnr + PSNR(ground_truth, baseline_output, opt.color)
end
if model2 then
if baseline_output then
io.stdout:write(
string.format("%d/%d; baseline_mse=%f, model1_mse=%f, model2_mse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, model2_rmse=%f, baseline_psnr=%f, model1_psnr=%f, model2_psnr=%f \r",
i, #x,
baseline_mse / i,
model1_mse / i, model2_mse / i,
math.sqrt(baseline_mse / i),
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
baseline_psnr / i,
model1_psnr / i, model2_psnr / i
))
else
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model2_mse=%f, model1_psnr=%f, model2_psnr=%f \r",
string.format("%d/%d; model1_rmse=%f, model2_rmse=%f, model1_psnr=%f, model2_psnr=%f \r",
i, #x,
model1_mse / i, model2_mse / i,
math.sqrt(model1_mse / i), math.sqrt(model2_mse / i),
model1_psnr / i, model2_psnr / i
))
end
else
if baseline_output then
io.stdout:write(
string.format("%d/%d; baseline_mse=%f, model1_mse=%f, baseline_psnr=%f, model1_psnr=%f \r",
string.format("%d/%d; baseline_rmse=%f, model1_rmse=%f, baseline_psnr=%f, model1_psnr=%f \r",
i, #x,
baseline_mse / i, model1_mse / i,
math.sqrt(baseline_mse / i), math.sqrt(model1_mse / i),
baseline_psnr / i, model1_psnr / i
))
else
io.stdout:write(
string.format("%d/%d; model1_mse=%f, model1_psnr=%f \r",
string.format("%d/%d; model1_rmse=%f, model1_psnr=%f \r",
i, #x,
model1_mse / i, model1_psnr / i
math.sqrt(model1_mse / i), model1_psnr / i
))
end
end
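
The progress line now reports RMSE (the square root of the running mean MSE) rather than raw MSE, which keeps the figure in pixel units. A quick worked example with made-up numbers:
```lua
local mse_sum, i = 48.0, 12 -- accumulated MSE over 12 images (assumed values)
local rmse = math.sqrt(mse_sum / i)                           -- 2.0
local psnr = 10 * math.log10((255.0 * 255.0) / (mse_sum / i)) -- ~42.11 dB
print(rmse, psnr)
```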

120
train.lua
View file

@@ -35,14 +35,14 @@ local function split_data(x, test_size)
end
return train_x, valid_x
end
local function make_validation_set(x, transformer, n, batch_size)
local function make_validation_set(x, transformer, n, patches)
n = n or 4
local data = {}
for i = 1, #x do
for k = 1, math.max(n / batch_size, 1) do
local xy = transformer(x[i], true, batch_size)
local tx = torch.Tensor(batch_size, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
local ty = torch.Tensor(batch_size, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
for k = 1, math.max(n / patches, 1) do
local xy = transformer(x[i], true, patches)
local tx = torch.Tensor(patches, xy[1][1]:size(1), xy[1][1]:size(2), xy[1][1]:size(3))
local ty = torch.Tensor(patches, xy[1][2]:size(1), xy[1][2]:size(2), xy[1][2]:size(3))
for j = 1, #xy do
tx[j]:copy(xy[j][1])
ty[j]:copy(xy[j][2])
@@ -83,7 +83,8 @@ local function create_criterion(model)
end
local function transformer(x, is_validation, n, offset)
x = compression.decompress(x)
n = n or settings.batch_size;
n = n or settings.patches
if is_validation == nil then is_validation = false end
local random_color_noise_rate = nil
local random_overlay_rate = nil
@@ -110,6 +111,7 @@ local function transformer(x, is_validation, n, offset)
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
@@ -125,8 +127,9 @@ local function transformer(x, is_validation, n, offset)
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
random_unsharp_mask_rate = settings.random_unsharp_mask_rate,
max_size = settings.max_size,
jpeg_sampling_factors = settings.jpeg_sampling_factors,
jpeg_chroma_subsampling_rate = settings.jpeg_chroma_subsampling_rate,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
nr_rate = settings.nr_rate,
@@ -135,7 +138,24 @@ local function transformer(x, is_validation, n, offset)
end
end
local function resampling(x, y, train_x, transformer, input_size, target_size)
print("## resampling")
for t = 1, #train_x do
xlua.progress(t, #train_x)
local xy = transformer(train_x[t], false, settings.patches)
for i = 1, #xy do
local index = (t - 1) * settings.patches + i
x[index]:copy(xy[i][1])
y[index]:copy(xy[i][2])
end
if t % 50 == 0 then
collectgarbage()
end
end
end
local function train()
local LR_MIN = 1.0e-5
local model = srcnn.create(settings.method, settings.backend, settings.color)
local offset = reconstruct.offset_size(model)
local pairwise_func = function(x, is_validation, n)
@@ -143,12 +163,12 @@ local function train()
end
local criterion = create_criterion(model)
local x = torch.load(settings.images)
local lrd_count = 0
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
local adam_config = {
learningRate = settings.learning_rate,
xBatchSize = settings.batch_size,
}
local lrd_count = 0
local ch = nil
if settings.color == "y" then
ch = 1
@@ -159,48 +179,70 @@ local function train()
print("# make validation-set")
local valid_xy = make_validation_set(valid_x, pairwise_func,
settings.validation_crops,
settings.batch_size)
settings.patches)
valid_x = nil
collectgarbage()
model:cuda()
print("load .. " .. #train_x)
local x = torch.Tensor(settings.patches * #train_x,
ch, settings.crop_size, settings.crop_size)
local y = torch.Tensor(settings.patches * #train_x,
ch * (settings.crop_size - offset * 2) * (settings.crop_size - offset * 2)):zero()
for epoch = 1, settings.epoch do
model:training()
print("# " .. epoch)
print(minibatch_adam(model, criterion, train_x, adam_config,
pairwise_func,
{ch, settings.crop_size, settings.crop_size},
{ch, settings.crop_size - offset * 2, settings.crop_size - offset * 2}
))
model:evaluate()
print("# validation")
local score = validate(model, criterion, valid_xy)
if score < best_score then
local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0
best_score = score
print("* update best model")
torch.save(settings.model_file, model)
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.png"):format(settings.noise_level))
save_test_jpeg(model, test_image, log)
elseif settings.method == "scale" then
local log = path.join(settings.model_dir,
("scale%.1f_best.png"):format(settings.scale))
save_test_scale(model, test_image, log)
end
else
lrd_count = lrd_count + 1
if lrd_count > 5 then
resampling(x, y, train_x, pairwise_func)
for i = 1, settings.inner_epoch do
print(minibatch_adam(model, criterion, x, y, adam_config))
model:evaluate()
print("# validation")
local score = validate(model, criterion, valid_xy)
if score < best_score then
local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0
adam_config.learningRate = adam_config.learningRate * 0.9
print("* learning rate decay: " .. adam_config.learningRate)
best_score = score
print("* update best model")
if settings.save_history then
local model_clone = model:clone()
w2nn.cleanup_model(model_clone)
torch.save(string.format(settings.model_file, epoch, i), model_clone)
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.%d-%d.png"):format(settings.noise_level,
epoch, i))
save_test_jpeg(model, test_image, log)
elseif settings.method == "scale" then
local log = path.join(settings.model_dir,
("scale%.1f_best.%d-%d.png"):format(settings.scale,
epoch, i))
save_test_scale(model, test_image, log)
end
else
torch.save(settings.model_file, model)
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.png"):format(settings.noise_level))
save_test_jpeg(model, test_image, log)
elseif settings.method == "scale" then
local log = path.join(settings.model_dir,
("scale%.1f_best.png"):format(settings.scale))
save_test_scale(model, test_image, log)
end
end
else
lrd_count = lrd_count + 1
if lrd_count > 2 and adam_config.learningRate > LR_MIN then
adam_config.learningRate = adam_config.learningRate * 0.8
print("* learning rate decay: " .. adam_config.learningRate)
lrd_count = 0
end
end
print("current: " .. score .. ", best: " .. best_score)
collectgarbage()
end
print("current: " .. score .. ", best: " .. best_score)
collectgarbage()
end
end
if settings.gpu > 0 then
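
The new schedule decays the learning rate in two places: by 0.9 whenever the validation score improves, and by 0.8 after more than two consecutive stalls, floored at `LR_MIN`. Distilled from the loop above (same constants, my own framing):
```lua
local lr, LR_MIN, lrd_count = 0.001, 1.0e-5, 0

local function update_learning_rate(improved)
   if improved then
      lrd_count = 0
      lr = lr * 0.9 -- decayed even on improvement, as in the diff above
   else
      lrd_count = lrd_count + 1
      if lrd_count > 2 and lr > LR_MIN then
         lr = lr * 0.8
         lrd_count = 0
      end
   end
end
```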

14
web.lua
View file

@@ -12,8 +12,9 @@ local iproc = require 'iproc'
local reconstruct = require 'reconstruct'
local image_loader = require 'image_loader'
local alpha_util = require 'alpha_util'
local gm = require 'graphicsmagick'
-- Notes: turbo and xlua has different implementation of string:split().
-- Note: turbo and xlua have different implementations of string:split().
-- Therefore, string:split() has a conflict issue.
-- In this script, use turbo's string:split().
local turbo = require 'turbo'
@@ -36,13 +37,13 @@ if cudnn then
cudnn.benchmark = false
end
local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "ukbench")
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
--local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
--local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
--local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
local CLEANUP_MODEL = false -- if you are using a low-memory GPU, you can enable this flag.
local CACHE_DIR = path.join(ROOT, "cache")
local MAX_NOISE_IMAGE = 2560 * 2560
@@ -143,7 +144,7 @@ local function convert(x, alpha, options)
x = reconstruct.image(art_noise2_model, x)
cleanup_model(art_noise2_model)
end
else --[[photo
else -- photo
if options.border then
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(photo_scale2_model))
end
@@ -163,7 +164,6 @@ local function convert(x, alpha, options)
x = reconstruct.image(photo_noise2_model, x)
cleanup_model(photo_noise2_model)
end
--]]
end
image_loader.save_png(cache_file, x)
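
With the photo models now loaded unconditionally, the photo branch mirrors the art path. A rough end-to-end sketch of that path (the `load_float` and `reconstruct.scale` signatures are assumed from the rest of the codebase, not shown in this diff):
```lua
-- hypothetical standalone use of the models enabled above
local x, alpha = image_loader.load_float("input.jpg")
x = reconstruct.image(photo_noise1_model, x)      -- denoise (noise level 1)
x = reconstruct.scale(photo_scale2_model, 2.0, x) -- then 2x upscale
image_loader.save_png("output.png", x)
```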