
Merge pull request #95 from nagadomi/dev

Merge from dev branch
nagadomi 2016-03-27 18:03:08 +09:00
commit 7849f51f42
44 changed files with 208 additions and 165 deletions

appendix/benchmark.md Normal file (+38)
View file

@@ -0,0 +1,38 @@
# Benchmark results
## Dataset
- photo_test: 300 various photos.
- art_test: 90 artworks (PNG only).
## 2x upscaling model
| Dataset/Model | anime\_style\_art(Y) | anime\_style\_art\_rgb | photo | ukbench|
|---------------|----------------------|------------------------|---------|--------|
| photo\_test | 29.83 | 29.81 |**29.89**| 29.86 |
| art\_test | 36.02 | **36.24**| 34.92 | 34.85 |
The evaluation metric is PSNR (Y channel only); higher is better.
## Denoising level 1 model
| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo |
|--------------------------|-------------------|------------------------|---------|
| photo\_test Quality 80 | 36.07 | **36.20**| 36.01 |
| photo\_test Quality 50,45| 31.72 | 32.01 |**32.31**|
| art\_test Quality 80 | 40.39 | **42.48**| 40.35 |
| art\_test Quality 50,45 | 35.45 | **36.70**| 36.27 |
The evaluation metric is PSNR (RGB); higher is better.
## Denoising level 2 model
| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo |
|--------------------------|-------------------|------------------------|---------|
| photo\_test Quality 80 | 34.03 | 34.42 |**36.06**|
| photo\_test Quality 50,45| 31.95 | 32.31 |**32.42**|
| art\_test Quality 80 | 39.20 | **41.12**| 40.48 |
| art\_test Quality 50,45 | 36.14 | **37.78**| 36.55 |
The evaluation metric is PSNR (RGB); higher is better.
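For reference, with pixel values normalized to [0, 1] the metric reduces to `PSNR = 10 * log10(1 / MSE)`; the `PSNRCriterion` added in this commit (lib/PSNRCriterion.lua, below) floors MSE at `(0.1/255)^2`, which caps reported values at about 68 dB.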

View file

@@ -106,6 +106,12 @@
Alto
</span>
</label>
<label>
<input type="radio" name="noise" class="radio" value="3">
<span class="r-text">
Highest
</span>
</label>
</div>
<div class="option-hint">
Es necesario utilizar la reducción de ruido si la imagen dispone de artefactos de compresión; de lo contrario podría producir el efecto opuesto.

View file

@@ -106,6 +106,12 @@
Haute
</span>
</label>
<label>
<input type="radio" name="noise" class="radio" value="3">
<span class="r-text">
Highest
</span>
</label>
</div>
<div class="option-hint">
Il est nécessaire d'utiliser la réduction du bruit si l'image possède du bruit. Autrement, cela risque de causer l'effet opposé.

View file

@@ -106,6 +106,12 @@
High
</span>
</label>
<label>
<input type="radio" name="noise" class="radio" value="3">
<span class="r-text">
Highest
</span>
</label>
</div>
<div class="option-hint">
You need to use noise reduction only if the image actually has noise; otherwise it may cause the opposite effect.

View file

@@ -106,6 +106,12 @@
</span>
</label>
<label>
<input type="radio" name="noise" class="radio" value="3">
<span class="r-text">
最高
</span>
</label>
</div>
<div class="option-hint">
ノイズ除去は細部が消えることがあります。JPEGノイズがある場合に使用します。

View file

@@ -106,6 +106,12 @@
Alta
</span>
</label>
<label>
<input type="radio" name="noise" class="radio" value="3">
<span class="r-text">
Highest
</span>
</label>
</div>
<div class="option-hint">
Quando usando a escala 2x, Nós nunca recomendamos usar um nível alto de redução de ruído, quase sempre deixa a imagem pior, faz sentido apenas para casos raros quando a imagem tinha uma qualidade muito má desde o começo.

View file

@@ -106,6 +106,12 @@
Сильно
</span>
</label>
<label>
<input type="radio" name="noise" class="radio" value="3">
<span class="r-text">
Highest
</span>
</label>
</div>
<div class="option-hint">
Устранение шума нужно использовать, если на картинке действительно есть шум, иначе это даст противоположный эффект.

View file

@@ -17,7 +17,7 @@ function LeakyReLU:updateOutput(input)
return self.output
end
function LeakyReLU:updateGradInput(input, gradOutput)
self.gradInput:resizeAs(gradOutput)
-- filter positive
@@ -29,3 +29,8 @@ function LeakyReLU:updateGradInput(input, gradOutput)
return self.gradInput
end
function LeakyReLU:clearState()
nn.utils.clear(self, 'negative')
return parent.clearState(self)
end
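For background, `clearState()` empties a module's temporary buffers in place via `nn.utils.clear`, which is what keeps serialized models small; a minimal sketch of the contract, using a stock `nn` module purely for illustration:

```lua
require 'nn'

local m = nn.Tanh()
m:forward(torch.rand(4))
assert(m.output:nElement() == 4)  -- forward allocated a scratch buffer
m:clearState()                    -- buffers emptied in place
assert(m.output:nElement() == 0)  -- now serializes to (almost) nothing
```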

lib/PSNRCriterion.lua Normal file (+19)
View file

@@ -0,0 +1,19 @@
local PSNRCriterion, parent = torch.class('w2nn.PSNRCriterion','nn.Criterion')
function PSNRCriterion:__init()
parent.__init(self)
self.image = torch.Tensor()
self.diff = torch.Tensor()
end
function PSNRCriterion:updateOutput(input, target)
self.image:resizeAs(input):copy(input)
self.image:clamp(0.0, 1.0)
self.diff:resizeAs(self.image):copy(self.image)
local mse = math.max(self.diff:add(-1, target):pow(2):mean(), (0.1/255)^2)
self.output = 10 * math.log10(1.0 / mse)
return self.output
end
function PSNRCriterion:updateGradInput(input, target)
error("PSNRCriterion does not support backward")
end
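A minimal usage sketch (assumes `w2nn` is loaded as elsewhere in this repo; the tensor shapes are illustrative):

```lua
-- Compare a reconstructed RGB patch against its ground truth.
local criterion = w2nn.PSNRCriterion()
local output = torch.rand(3, 32, 32)  -- e.g. a model:forward(...) result in [0, 1]
local target = torch.rand(3, 32, 32)  -- matching ground-truth patch
local psnr = criterion:forward(output, target)
print(("PSNR: %.2f dB"):format(psnr)) -- higher is better
```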

View file

@@ -1,48 +1,3 @@
-- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
local function zeroDataSize(data)
if type(data) == 'table' then
for i = 1, #data do
data[i] = zeroDataSize(data[i])
end
elseif type(data) == 'userdata' then
data = torch.Tensor():typeAs(data)
end
return data
end
-- Resize the output, gradInput, etc temporary tensors to zero (so that the
-- on disk size is smaller)
local function cleanupModel(node)
if node.output ~= nil then
node.output = zeroDataSize(node.output)
end
if node.gradInput ~= nil then
node.gradInput = zeroDataSize(node.gradInput)
end
if node.finput ~= nil then
node.finput = zeroDataSize(node.finput)
end
if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
if node.negative ~= nil then
node.negative = zeroDataSize(node.negative)
end
end
if tostring(node) == "nn.Dropout" then
if node.noise ~= nil then
node.noise = zeroDataSize(node.noise)
end
end
-- Recurse on nodes with 'modules'
if (node.modules ~= nil) then
if (type(node.modules) == 'table') then
for i = 1, #node.modules do
local child = node.modules[i]
cleanupModel(child)
end
end
end
end
function w2nn.cleanup_model(model)
cleanupModel(model)
return model
return model:clearState()
end
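With this change `w2nn.cleanup_model` simply delegates to `clearState()`; a sketch of the save path it enables (layer sizes here are illustrative, not from the repo's models):

```lua
-- clearState() before torch.save() drops output/gradInput/finput buffers,
-- so only weights and biases hit the disk.
local model = nn.Sequential()
   :add(nn.SpatialConvolutionMM(3, 32, 3, 3))
   :add(w2nn.LeakyReLU(0.1))
model:forward(torch.rand(3, 64, 64))              -- allocates temporary buffers
torch.save("model.t7", w2nn.cleanup_model(model)) -- buffers cleared first
```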

View file

@@ -2,12 +2,13 @@ require 'optim'
require 'cutorch'
require 'xlua'
local function minibatch_adam(model, criterion,
local function minibatch_adam(model, criterion, eval_metric,
train_x, train_y,
config)
local parameters, gradParameters = model:getParameters()
config = config or {}
local sum_loss = 0
local sum_eval = 0
local count_loss = 0
local batch_size = config.xBatchSize or 32
local shuffle = torch.randperm(train_x:size(1))
@@ -39,6 +40,7 @@ local function minibatch_adam(model, criterion,
gradParameters:zero()
local output = model:forward(inputs)
local f = criterion:forward(output, targets)
sum_eval = sum_eval + eval_metric:forward(output, targets)
sum_loss = sum_loss + f
count_loss = count_loss + 1
model:backward(inputs, criterion:backward(output, targets))
@@ -52,7 +54,7 @@ local function minibatch_adam(model, criterion,
end
xlua.progress(train_x:size(1), train_x:size(1))
return { loss = sum_loss / count_loss}
return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
end
return minibatch_adam
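Callers now get the averaged PSNR back alongside the loss; a sketch of the consuming side (this mirrors the `train.lua` change later in this diff):

```lua
local stats = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
print(("loss=%.6f  PSNR=%.2f dB"):format(stats.loss, stats.PSNR))
```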

View file

@@ -82,30 +82,14 @@ local function active_cropping(x, y, size, p, tries)
end
end
function pairwise_transform.scale(src, scale, size, offset, n, options)
local filters;
if options.style == "photo" then
filters = {
"Box", "lanczos", "Catrom"
}
else
filters = {
"Box","Box", -- 0.012756949974688
"Blackman", -- 0.013191924552285
--"Catrom", -- 0.013753536746706
--"Hanning", -- 0.013761314529647
--"Hermite", -- 0.013850225205266
"Sinc", -- 0.014095824314306
"Lanczos", -- 0.014244299255442
}
end
local filters = options.downsampling_filters
local unstable_region_offset = 8
local downscale_filter = filters[torch.random(1, #filters)]
local downsampling_filter = filters[torch.random(1, #filters)]
local y = preprocess(src, size, options)
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
local down_scale = 1.0 / scale
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
y:size(2) * down_scale, downscale_filter),
y:size(2) * down_scale, downsampling_filter),
y:size(3), y:size(2))
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
@@ -184,7 +168,8 @@ function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
if level == 1 then
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
size, offset, n, options)
elseif level == 2 then
elseif level == 2 or level == 3 then
-- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
local r = torch.uniform()
if r > 0.6 then
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},

View file

@@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
cmd:option("-test", "images/miku_small.png", 'path to test image')
cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", 'method to training (noise|scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-noise_level", 1, '(1|2|3)')
cmd:option("-style", "art", '(art|photo)')
cmd:option("-color", 'rgb', '(y|rgb)')
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
@@ -42,16 +42,24 @@ cmd:option("-epoch", 30, 'number of epochs to run')
cmd:option("-thread", -1, 'number of CPU threads')
cmd:option("-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)')
cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
cmd:option("-validation_crops", 160, 'number of cropping region per image in validation')
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
cmd:option("-save_history", 0, 'save all model (0|1)')
cmd:option("-plot", 0, 'plot loss chart(0|1)')
cmd:option("-downsampling_filters", "Box,Catrom", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)')
local opt = cmd:parse(arg)
for k, v in pairs(opt) do
settings[k] = v
end
if settings.plot == 1 then
settings.plot = true
require 'gnuplot'
else
settings.plot = false
end
if settings.save_history == 1 then
settings.save_history = true
else
@@ -88,10 +96,14 @@ if not (settings.style == "art" or
settings.style == "photo") then
error(string.format("unknown style: %s", settings.style))
end
if settings.thread > 0 then
torch.setnumthreads(tonumber(settings.thread))
end
if settings.downsampling_filters and settings.downsampling_filters:len() > 0 then
settings.downsampling_filters = settings.downsampling_filters:split(",")
else
settings.downsampling_filters = {"Box", "Lanczos", "Catrom"}
end
settings.images = string.format("%s/images.t7", settings.data_dir)
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
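A sketch of how the new flag flows from the command line into the `pairwise_transform.scale` change earlier in this diff (the `split` string extension comes from the project's dependencies; values are illustrative):

```lua
-- "-downsampling_filters Box,Catrom" arrives as one string...
local filters = ("Box,Catrom"):split(",")    -- { "Box", "Catrom" }
-- ...and one filter is then sampled per training pair:
local downsampling_filter = filters[torch.random(1, #filters)]
```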

View file

@@ -17,6 +17,16 @@ if cudnn and cudnn.SpatialConvolution then
end
end
function nn.SpatialConvolutionMM:clearState()
if self.gradWeight then
self.gradWeight = torch.Tensor(self.nOutputPlane, self.nInputPlane * self.kH * self.kW):typeAs(self.gradWeight):zero()
end
if self.gradBias then
self.gradBias = torch.Tensor(self.nOutputPlane):typeAs(self.gradBias):zero()
end
return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
function srcnn.channels(model)
return model:get(model:size() - 1).weight:size(1)
end

View file

@@ -19,8 +19,7 @@ else
require 'LeakyReLU'
require 'LeakyReLU_deprecated'
require 'DepthExpand2x'
require 'WeightedMSECriterion'
require 'PSNRCriterion'
require 'ClippedWeightedHuberCriterion'
require 'cleanup_model'
return w2nn
end

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@@ -1,25 +0,0 @@
require 'pl'
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'w2nn'
torch.setdefaulttensortype("torch.FloatTensor")
local cmd = torch.CmdLine()
cmd:text()
cmd:text("cleanup model")
cmd:text("Options:")
cmd:option("-model", "./model.t7", 'path of model file')
cmd:option("-iformat", "binary", 'input format')
cmd:option("-oformat", "binary", 'output format')
local opt = cmd:parse(arg)
local model = torch.load(opt.model, opt.iformat)
if model then
w2nn.cleanup_model(model)
model:cuda()
model:evaluate()
torch.save(opt.model, model, opt.oformat)
else
error("model not found")
end

View file

@@ -24,11 +24,6 @@ function export(model, output)
}
table.insert(jmodules, jmod)
end
jmodules[1].color = "RGB"
jmodules[1].gamma = 0
jmodules[#jmodules].color = "RGB"
jmodules[#jmodules].gamma = 0
local fp = io.open(output, "w")
if not fp then
error("IO Error: " .. output)

View file

@@ -78,7 +78,11 @@ local function create_criterion(model)
weight[3]:fill(0.11448 * 3) -- B
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
else
return nn.MSECriterion():cuda()
local offset = reconstruct.offset_size(model)
local output_w = settings.crop_size - offset * 2
local weight = torch.Tensor(1, output_w * output_w)
weight[1]:fill(1.0)
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
end
end
local function transformer(x, is_validation, n, offset)
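The Y-channel branch now mirrors the RGB branch with uniform weights instead of plain MSE; a sketch with illustrative numbers (the crop size and model offset here are assumptions, not values from this diff):

```lua
local crop_size, offset = 46, 7                  -- assumed for illustration
local output_w = crop_size - offset * 2          -- 32 output pixels per side
local weight = torch.Tensor(1, output_w * output_w):fill(1.0)
-- w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}) then acts as a
-- Huber loss with outputs clipped to [0, 1], weighted equally per pixel.
```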
@@ -91,8 +95,8 @@ local function transformer(x, is_validation, n, offset)
local active_cropping_rate = nil
local active_cropping_tries = nil
if is_validation then
active_cropping_rate = 0
active_cropping_tries = 0
active_cropping_rate = settings.active_cropping_rate
active_cropping_tries = settings.active_cropping_tries
random_color_noise_rate = 0.0
random_overlay_rate = 0.0
else
@@ -108,6 +112,7 @@ local function transformer(x, is_validation, n, offset)
settings.crop_size, offset,
n,
{
downsampling_filters = settings.downsampling_filters,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
@@ -153,8 +158,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
end
end
end
local function plot(train, valid)
gnuplot.plot({
{'training', torch.Tensor(train), '-'},
{'validation', torch.Tensor(valid), '-'}})
end
local function train()
local hist_train = {}
local hist_valid = {}
local LR_MIN = 1.0e-5
local model = srcnn.create(settings.method, settings.backend, settings.color)
local offset = reconstruct.offset_size(model)
@@ -162,6 +173,7 @@ local function train()
return transformer(x, is_validation, n, offset)
end
local criterion = create_criterion(model)
local eval_metric = w2nn.PSNRCriterion():cuda()
local x = torch.load(settings.images)
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
local adam_config = {
@@ -175,7 +187,7 @@ local function train()
elseif settings.color == "rgb" then
ch = 3
end
local best_score = 100000.0
local best_score = 0.0
print("# make validation-set")
local valid_xy = make_validation_set(valid_x, pairwise_func,
settings.validation_crops,
@@ -196,19 +208,24 @@ local function train()
print("# " .. epoch)
resampling(x, y, train_x, pairwise_func)
for i = 1, settings.inner_epoch do
print(minibatch_adam(model, criterion, x, y, adam_config))
local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
print(train_score)
model:evaluate()
print("# validation")
local score = validate(model, criterion, valid_xy)
if score < best_score then
local score = validate(model, eval_metric, valid_xy)
table.insert(hist_train, train_score.PSNR)
table.insert(hist_valid, score)
if settings.plot then
plot(hist_train, hist_valid)
end
if score > best_score then
local test_image = image_loader.load_float(settings.test) -- reload
lrd_count = 0
best_score = score
print("* update best model")
if settings.save_history then
local model_clone = model:clone()
w2nn.cleanup_model(model_clone)
torch.save(string.format(settings.model_file, epoch, i), model_clone)
torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.%d-%d.png"):format(settings.noise_level,
@@ -221,7 +238,7 @@ local function train()
save_test_scale(model, test_image, log)
end
else
torch.save(settings.model_file, model)
torch.save(settings.model_file, model:clearState(), "ascii")
if settings.method == "noise" then
local log = path.join(settings.model_dir,
("noise%d_best.png"):format(settings.noise_level))

View file

@@ -3,10 +3,5 @@
th convert_data.lua
th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii

View file

@@ -2,11 +2,8 @@
th convert_data.lua -style photo -data_dir ./data/photo -model_dir models/photo
th train.lua -style photo -method scale -data_dir ./data/photo -model_dir models/photo_uk -test work/scale_test_photo.png -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.1 -validation_crops 160
th tools/cleanup_model.lua -model models/photo/scale2.0x_model.t7 -oformat ascii
th train.lua -style photo -method scale -data_dir ./data/photo -model_dir models/photo -test work/scale_test_photo.png -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.1 -validation_crops 160
th train.lua -style photo -method noise -noise_level 1 -data_dir ./data/photo -model_dir models/photo -test work/noise_test_photo.jpg -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.5 -validation_crops 160 -nr_rate 0.6 -epoch 33
th tools/cleanup_model.lua -model models/photo/noise1_model.t7 -oformat ascii
th train.lua -style photo -method noise -noise_level 2 -data_dir ./data/photo -model_dir models/photo -test work/noise_test_photo.jpg -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.5 -validation_crops 160 -nr_rate 0.8 -epoch 38
th tools/cleanup_model.lua -model models/photo/noise2_model.t7 -oformat ascii

View file

@@ -69,7 +69,8 @@ local function convert_image(opt)
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
end
local function convert_frames(opt)
local model_path, noise1_model, noise2_model, scale_model
local model_path, scale_model
local noise_model = {}
local scale_f, image_f
if opt.tta == 1 then
scale_f = reconstruct.scale_tta
@@ -84,16 +85,10 @@
if not scale_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise" and opt.noise_level == 1 then
model_path = path.join(opt.model_dir, "noise1_model.t7")
noise1_model = torch.load(model_path, "ascii")
if not noise1_model then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise" and opt.noise_level == 2 then
model_path = path.join(opt.model_dir, "noise2_model.t7")
noise2_model = torch.load(model_path, "ascii")
if not noise2_model then
elseif opt.m == "noise" then
model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
noise_model[opt.noise_level] = torch.load(model_path, "ascii")
if not noise_model[opt.noise_level] then
error("Load Error: " .. model_path)
end
elseif opt.m == "noise_scale" then
@@ -102,18 +97,10 @@
if not scale_model then
error("Load Error: " .. model_path)
end
if opt.noise_level == 1 then
model_path = path.join(opt.model_dir, "noise1_model.t7")
noise1_model = torch.load(model_path, "ascii")
if not noise1_model then
error("Load Error: " .. model_path)
end
elseif opt.noise_level == 2 then
model_path = path.join(opt.model_dir, "noise2_model.t7")
noise2_model = torch.load(model_path, "ascii")
if not noise2_model then
error("Load Error: " .. model_path)
end
model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
noise_model[opt.noise_level] = torch.load(model_path, "ascii")
if not noise_model[opt.noise_level] then
error("Load Error: " .. model_path)
end
end
local fp = io.open(opt.l)
@@ -130,24 +117,16 @@
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
local x, alpha = image_loader.load_float(lines[i])
local new_x = nil
if opt.m == "noise" and opt.noise_level == 1 then
new_x = image_f(noise1_model, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha)
elseif opt.m == "noise" and opt.noise_level == 2 then
new_x = image_f(noise2_model, x, opt.crop_size)
if opt.m == "noise" then
new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha)
elseif opt.m == "scale" then
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha, scale_model)
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
elseif opt.m == "noise_scale" then
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
x = image_f(noise1_model, x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha, scale_model)
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
x = image_f(noise2_model, x, opt.crop_size)
x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
new_x = alpha_util.composite(new_x, alpha, scale_model)
else
@@ -185,7 +164,7 @@ local function waifu2x()
cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-noise_level", 1, '(1|2|3)')
cmd:option("-crop_size", 128, 'patch size per process')
cmd:option("-resume", 0, "skip existing files (0|1)")
cmd:option("-thread", -1, "number of CPU threads")

web.lua (+24)
View file

@@ -38,12 +38,14 @@ if cudnn then
end
local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
local art_noise3_model = torch.load(path.join(ART_MODEL_DIR, "noise3_model.t7"), "ascii")
local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
local photo_noise3_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise3_model.t7"), "ascii")
local CLEANUP_MODEL = false -- if you are using a low-memory GPU, you can enable this flag.
local CACHE_DIR = path.join(ROOT, "cache")
local MAX_NOISE_IMAGE = 2560 * 2560
@@ -114,7 +116,7 @@ local function get_image(req)
end
local function cleanup_model(model)
if CLEANUP_MODEL then
w2nn.cleanup_model(model) -- release GPU memory
model:clearState() -- release GPU memory
end
end
local function convert(x, alpha, options)
@@ -151,9 +153,12 @@ local function convert(x, alpha, options)
elseif options.method == "noise1" then
x = reconstruct.image(art_noise1_model, x)
cleanup_model(art_noise1_model)
else -- options.method == "noise2"
elseif options.method == "noise2" then
x = reconstruct.image(art_noise2_model, x)
cleanup_model(art_noise2_model)
elseif options.method == "noise3" then
x = reconstruct.image(art_noise3_model, x)
cleanup_model(art_noise3_model)
end
else -- photo
if options.border then
@@ -174,6 +179,9 @@ local function convert(x, alpha, options)
elseif options.method == "noise2" then
x = reconstruct.image(photo_noise2_model, x)
cleanup_model(photo_noise2_model)
elseif options.method == "noise3" then
x = reconstruct.image(photo_noise3_model, x)
cleanup_model(photo_noise3_model)
end
end
image_loader.save_png(cache_file, x)
@@ -229,17 +237,25 @@ function APIHandler:post()
alpha_prefix = alpha_prefix, border = border})
border = false
elseif noise == 2 then
prefix = style .. "_noise1_"
prefix = style .. "_noise2_"
x = convert(x, alpha, {method = "noise2", style = style,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
elseif noise == 3 then
prefix = style .. "_noise3_"
x = convert(x, alpha, {method = "noise3", style = style,
prefix = prefix .. hash,
alpha_prefix = alpha_prefix, border = border})
border = false
end
if scale == 1 or scale == 2 then
if noise == 1 then
prefix = style .. "_noise1_scale_"
elseif noise == 2 then
prefix = style .. "_noise2_scale_"
elseif noise == 3 then
prefix = style .. "_noise3_scale_"
else
prefix = style .. "_scale_"
end

View file

@@ -14,6 +14,7 @@ expect_jpeg: expect JPEG artifact
nr_none: None
nr_medium: Medium
nr_high: High
nr_highest: Highest
nr_hint: "You need to use noise reduction only if the image actually has noise; otherwise it may cause the opposite effect."
upscaling: Upscaling
up_none: None

View file

@@ -14,6 +14,7 @@ expect_jpeg: JPEGノイズを想定
nr_none: なし
nr_medium: 中
nr_high: 高
nr_highest: 最高
nr_hint: "ノイズ除去は細部が消えることがあります。JPEGノイズがある場合に使用します。"
upscaling: 拡大
up_none: なし

View file

@@ -106,6 +106,12 @@
<%= t[:nr_high] %>
</span>
</label>
<label>
<input type="radio" name="noise" class="radio" value="3">
<span class="r-text">
<%= t[:nr_highest] %>
</span>
</label>
</div>
<div class="option-hint">
<%= t[:nr_hint] %>