commit
7849f51f42
38
appendix/benchmark.md
Normal file
38
appendix/benchmark.md
Normal file
|
@ -0,0 +1,38 @@
|
|||
# Benchmark results
|
||||
|
||||
|
||||
## dataset
|
||||
|
||||
photo_set: 300 various photos.
|
||||
art_set : 90 artworks (PNG only).
|
||||
|
||||
## 2x upscaling model
|
||||
|
||||
| Dataset/Model | anime\_style\_art(Y) | anime\_style\_art\_rgb | photo | ukbench|
|
||||
|---------------|----------------------|------------------------|---------|--------|
|
||||
| photo\_test | 29.83 | 29.81 |**29.89**| 29.86 |
|
||||
| art\_test | 36.02 | **36.24**| 34.92 | 34.85 |
|
||||
|
||||
The evaluation metric is PSNR(Y only), higher is better.
|
||||
|
||||
## Denoising level 1 model
|
||||
|
||||
| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo |
|
||||
|--------------------------|-------------------|------------------------|---------|
|
||||
| photo\_test Quality 80 | 36.07 | **36.20**| 36.01 |
|
||||
| photo\_test Quality 50,45| 31.72 | 32.01 |**32.31**|
|
||||
| art\_test Quality 80 | 40.39 | **42.48**| 40.35 |
|
||||
| art\_test Quality 50,45 | 35.45 | **36.70**| 36.27 |
|
||||
|
||||
The evaluation metric is PSNR(RGB), higher is better.
|
||||
|
||||
## Denoising level 2 model
|
||||
|
||||
| Dataset/Model | anime\_style\_art | anime\_style\_art\_rgb | photo |
|
||||
|--------------------------|-------------------|------------------------|---------|
|
||||
| photo\_test Quality 80 | 34.03 | 34.42 |**36.06**|
|
||||
| photo\_test Quality 50,45| 31.95 | 32.31 |**32.42**|
|
||||
| art\_test Quality 80 | 39.20 | **41.12**| 40.48 |
|
||||
| art\_test Quality 50,45 | 36.14 | **37.78**| 36.55 |
|
||||
|
||||
The evaluation metric is PSNR(RGB), higher is better.
|
|
@ -106,6 +106,12 @@
|
|||
Alto
|
||||
</span>
|
||||
</label>
|
||||
<label>
|
||||
<input type="radio" name="noise" class="radio" value="3">
|
||||
<span class="r-text">
|
||||
Highest
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="option-hint">
|
||||
Es necesario utilizar la reducción de ruido si la imagen dispone de artefactos de compresión; de lo contrario podría producir el efecto opuesto.
|
||||
|
|
|
@ -106,6 +106,12 @@
|
|||
Haute
|
||||
</span>
|
||||
</label>
|
||||
<label>
|
||||
<input type="radio" name="noise" class="radio" value="3">
|
||||
<span class="r-text">
|
||||
Highest
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="option-hint">
|
||||
Il est nécessaire d'utiliser la réduction du bruit si l'image possède du bruit. Autrement, cela risque de causer l'effet opposé.
|
||||
|
|
|
@ -106,6 +106,12 @@
|
|||
High
|
||||
</span>
|
||||
</label>
|
||||
<label>
|
||||
<input type="radio" name="noise" class="radio" value="3">
|
||||
<span class="r-text">
|
||||
Highest
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="option-hint">
|
||||
You need use noise reduction if image actually has noise or it may cause opposite effect.
|
||||
|
|
|
@ -106,6 +106,12 @@
|
|||
高
|
||||
</span>
|
||||
</label>
|
||||
<label>
|
||||
<input type="radio" name="noise" class="radio" value="3">
|
||||
<span class="r-text">
|
||||
最高
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="option-hint">
|
||||
ノイズ除去は細部が消えることがあります。JPEGノイズがある場合に使用します。
|
||||
|
|
|
@ -106,6 +106,12 @@
|
|||
Alta
|
||||
</span>
|
||||
</label>
|
||||
<label>
|
||||
<input type="radio" name="noise" class="radio" value="3">
|
||||
<span class="r-text">
|
||||
Highest
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="option-hint">
|
||||
Quando usando a escala 2x, Nós nunca recomendamos usar um nível alto de redução de ruído, quase sempre deixa a imagem pior, faz sentido apenas para casos raros quando a imagem tinha uma qualidade muito má desde o começo.
|
||||
|
|
|
@ -106,6 +106,12 @@
|
|||
Сильно
|
||||
</span>
|
||||
</label>
|
||||
<label>
|
||||
<input type="radio" name="noise" class="radio" value="3">
|
||||
<span class="r-text">
|
||||
Highest
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="option-hint">
|
||||
Устранение шума нужно использовать, если на картинке действительно есть шум, иначе это даст противоположный эффект.
|
||||
|
|
|
@ -17,7 +17,7 @@ function LeakyReLU:updateOutput(input)
|
|||
|
||||
return self.output
|
||||
end
|
||||
|
||||
|
||||
function LeakyReLU:updateGradInput(input, gradOutput)
|
||||
self.gradInput:resizeAs(gradOutput)
|
||||
-- filter positive
|
||||
|
@ -29,3 +29,8 @@ function LeakyReLU:updateGradInput(input, gradOutput)
|
|||
|
||||
return self.gradInput
|
||||
end
|
||||
|
||||
--- Release this module's cached state.
-- Clears the `negative` field through nn.utils.clear, then returns whatever
-- the parent class's clearState returns.
-- NOTE(review): `parent` is bound outside this hunk (presumably by the
-- torch.class call at the top of the file) -- confirm.
function LeakyReLU:clearState()
   nn.utils.clear(self, 'negative')
   return parent.clearState(self)
end
|
||||
|
|
19
lib/PSNRCriterion.lua
Normal file
19
lib/PSNRCriterion.lua
Normal file
|
@ -0,0 +1,19 @@
|
|||
--- PSNR evaluation metric exposed through the nn.Criterion interface.
-- Forward-only: updateOutput computes the score, updateGradInput raises.
local PSNRCriterion, parent = torch.class('w2nn.PSNRCriterion','nn.Criterion')

--- Create the criterion with two empty scratch tensors.
function PSNRCriterion:__init()
   parent.__init(self)
   -- Both buffers are resized on every updateOutput call.
   self.image = torch.Tensor()
   self.diff = torch.Tensor()
end

--- Compute 10 * log10(1 / MSE) between the clamped input and the target.
-- @param input  predicted tensor; a copy of it is clamped to [0.0, 1.0]
-- @param target reference tensor subtracted from the clamped copy
-- @return the score, which is also stored in self.output
function PSNRCriterion:updateOutput(input, target)
   -- Work on a private copy so the caller's tensor is never modified.
   local clamped = self.image
   clamped:resizeAs(input)
   clamped:copy(input)
   clamped:clamp(0.0, 1.0)

   -- Squared difference against the target, accumulated in place.
   local residual = self.diff
   residual:resizeAs(clamped)
   residual:copy(clamped)
   residual:add(-1, target)
   residual:pow(2)

   -- The floor keeps the logarithm finite when the two tensors match exactly.
   local mse_floor = (0.1/255)^2
   local mse = math.max(residual:mean(), mse_floor)
   self.output = 10 * math.log10(1.0 / mse)
   return self.output
end

--- Not implemented: this criterion is a metric, not a training loss.
function PSNRCriterion:updateGradInput(input, target)
   error("PSNRCriterion does not support backward")
end
|
|
@ -1,48 +1,3 @@
|
|||
-- ref: https://github.com/torch/nn/issues/112#issuecomment-64427049
|
||||
|
||||
local function zeroDataSize(data)
|
||||
if type(data) == 'table' then
|
||||
for i = 1, #data do
|
||||
data[i] = zeroDataSize(data[i])
|
||||
end
|
||||
elseif type(data) == 'userdata' then
|
||||
data = torch.Tensor():typeAs(data)
|
||||
end
|
||||
return data
|
||||
end
|
||||
-- Resize the output, gradInput, etc temporary tensors to zero (so that the
|
||||
-- on disk size is smaller)
|
||||
local function cleanupModel(node)
|
||||
if node.output ~= nil then
|
||||
node.output = zeroDataSize(node.output)
|
||||
end
|
||||
if node.gradInput ~= nil then
|
||||
node.gradInput = zeroDataSize(node.gradInput)
|
||||
end
|
||||
if node.finput ~= nil then
|
||||
node.finput = zeroDataSize(node.finput)
|
||||
end
|
||||
if tostring(node) == "nn.LeakyReLU" or tostring(node) == "w2nn.LeakyReLU" then
|
||||
if node.negative ~= nil then
|
||||
node.negative = zeroDataSize(node.negative)
|
||||
end
|
||||
end
|
||||
if tostring(node) == "nn.Dropout" then
|
||||
if node.noise ~= nil then
|
||||
node.noise = zeroDataSize(node.noise)
|
||||
end
|
||||
end
|
||||
-- Recurse on nodes with 'modules'
|
||||
if (node.modules ~= nil) then
|
||||
if (type(node.modules) == 'table') then
|
||||
for i = 1, #node.modules do
|
||||
local child = node.modules[i]
|
||||
cleanupModel(child)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
function w2nn.cleanup_model(model)
|
||||
cleanupModel(model)
|
||||
return model
|
||||
return model:clearState()
|
||||
end
|
||||
|
|
|
@ -2,12 +2,13 @@ require 'optim'
|
|||
require 'cutorch'
|
||||
require 'xlua'
|
||||
|
||||
local function minibatch_adam(model, criterion,
|
||||
local function minibatch_adam(model, criterion, eval_metric,
|
||||
train_x, train_y,
|
||||
config)
|
||||
local parameters, gradParameters = model:getParameters()
|
||||
config = config or {}
|
||||
local sum_loss = 0
|
||||
local sum_eval = 0
|
||||
local count_loss = 0
|
||||
local batch_size = config.xBatchSize or 32
|
||||
local shuffle = torch.randperm(train_x:size(1))
|
||||
|
@ -39,6 +40,7 @@ local function minibatch_adam(model, criterion,
|
|||
gradParameters:zero()
|
||||
local output = model:forward(inputs)
|
||||
local f = criterion:forward(output, targets)
|
||||
sum_eval = sum_eval + eval_metric:forward(output, targets)
|
||||
sum_loss = sum_loss + f
|
||||
count_loss = count_loss + 1
|
||||
model:backward(inputs, criterion:backward(output, targets))
|
||||
|
@ -52,7 +54,7 @@ local function minibatch_adam(model, criterion,
|
|||
end
|
||||
xlua.progress(train_x:size(1), train_x:size(1))
|
||||
|
||||
return { loss = sum_loss / count_loss}
|
||||
return { loss = sum_loss / count_loss, PSNR = sum_eval / count_loss}
|
||||
end
|
||||
|
||||
return minibatch_adam
|
||||
|
|
|
@ -82,30 +82,14 @@ local function active_cropping(x, y, size, p, tries)
|
|||
end
|
||||
end
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
local filters;
|
||||
|
||||
if options.style == "photo" then
|
||||
filters = {
|
||||
"Box", "lanczos", "Catrom"
|
||||
}
|
||||
else
|
||||
filters = {
|
||||
"Box","Box", -- 0.012756949974688
|
||||
"Blackman", -- 0.013191924552285
|
||||
--"Catrom", -- 0.013753536746706
|
||||
--"Hanning", -- 0.013761314529647
|
||||
--"Hermite", -- 0.013850225205266
|
||||
"Sinc", -- 0.014095824314306
|
||||
"Lanczos", -- 0.014244299255442
|
||||
}
|
||||
end
|
||||
local filters = options.downsampling_filters
|
||||
local unstable_region_offset = 8
|
||||
local downscale_filter = filters[torch.random(1, #filters)]
|
||||
local downsampling_filter = filters[torch.random(1, #filters)]
|
||||
local y = preprocess(src, size, options)
|
||||
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
|
||||
local down_scale = 1.0 / scale
|
||||
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
|
||||
y:size(2) * down_scale, downscale_filter),
|
||||
y:size(2) * down_scale, downsampling_filter),
|
||||
y:size(3), y:size(2))
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
|
@ -184,7 +168,8 @@ function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
|
|||
if level == 1 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(65, 85)},
|
||||
size, offset, n, options)
|
||||
elseif level == 2 then
|
||||
elseif level == 2 or level == 3 then
|
||||
-- level 2/3 adjusting by -nr_rate. for level3, -nr_rate=1
|
||||
local r = torch.uniform()
|
||||
if r > 0.6 then
|
||||
return pairwise_transform.jpeg_(src, {torch.random(27, 70)},
|
||||
|
|
|
@ -24,7 +24,7 @@ cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
|||
cmd:option("-test", "images/miku_small.png", 'path to test image')
|
||||
cmd:option("-model_dir", "./models", 'model directory')
|
||||
cmd:option("-method", "scale", 'method to training (noise|scale)')
|
||||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-noise_level", 1, '(1|2|3)')
|
||||
cmd:option("-style", "art", '(art|photo)')
|
||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
|
||||
|
@ -42,16 +42,24 @@ cmd:option("-epoch", 30, 'number of epochs to run')
|
|||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
cmd:option("-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)')
|
||||
cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
|
||||
cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
|
||||
cmd:option("-validation_crops", 160, 'number of cropping region per image in validation')
|
||||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||
cmd:option("-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
||||
cmd:option("-save_history", 0, 'save all model (0|1)')
|
||||
cmd:option("-plot", 0, 'plot loss chart(0|1)')
|
||||
cmd:option("-downsampling_filters", "Box,Catrom", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
for k, v in pairs(opt) do
|
||||
settings[k] = v
|
||||
end
|
||||
if settings.plot == 1 then
|
||||
settings.plot = true
|
||||
require 'gnuplot'
|
||||
else
|
||||
settings.plot = false
|
||||
end
|
||||
if settings.save_history == 1 then
|
||||
settings.save_history = true
|
||||
else
|
||||
|
@ -88,10 +96,14 @@ if not (settings.style == "art" or
|
|||
settings.style == "photo") then
|
||||
error(string.format("unknown style: %s", settings.style))
|
||||
end
|
||||
|
||||
if settings.thread > 0 then
|
||||
torch.setnumthreads(tonumber(settings.thread))
|
||||
end
|
||||
if settings.downsampling_filters and settings.downsampling_filters:len() > 0 then
|
||||
settings.downsampling_filters = settings.downsampling_filters:split(",")
|
||||
else
|
||||
settings.downsampling_filters = {"Box", "Lanczos", "Catrom"}
|
||||
end
|
||||
|
||||
settings.images = string.format("%s/images.t7", settings.data_dir)
|
||||
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
||||
|
|
|
@ -17,6 +17,16 @@ if cudnn and cudnn.SpatialConvolution then
|
|||
end
|
||||
end
|
||||
|
||||
--- Override of clearState for nn.SpatialConvolutionMM.
-- Replaces gradWeight and gradBias (when present) with freshly allocated,
-- zero-filled tensors of the same tensor type, then clears the listed
-- intermediate fields through nn.utils.clear and returns its result.
-- NOTE(review): presumably done to shrink the serialized model -- confirm.
function nn.SpatialConvolutionMM:clearState()
   local old_grad_weight = self.gradWeight
   if old_grad_weight then
      -- One row per output plane; columns cover every input plane and kernel cell.
      local columns = self.nInputPlane * self.kH * self.kW
      local fresh = torch.Tensor(self.nOutputPlane, columns)
      self.gradWeight = fresh:typeAs(old_grad_weight):zero()
   end

   local old_grad_bias = self.gradBias
   if old_grad_bias then
      local fresh = torch.Tensor(self.nOutputPlane)
      self.gradBias = fresh:typeAs(old_grad_bias):zero()
   end

   return nn.utils.clear(self, 'finput', 'fgradInput', '_input', '_gradOutput', 'output', 'gradInput')
end
|
||||
|
||||
--- Size of the first weight dimension of the second-to-last module.
-- NOTE(review): presumably the number of output channels of the final
-- convolution, with the last module being weightless -- confirm against
-- the model builders.
-- @param model container exposing :size() and :get(index)
-- @return result of weight:size(1) on the module at index size() - 1
function srcnn.channels(model)
   local index = model:size() - 1
   local layer = model:get(index)
   return layer.weight:size(1)
end
|
||||
|
|
|
@ -19,8 +19,7 @@ else
|
|||
require 'LeakyReLU'
|
||||
require 'LeakyReLU_deprecated'
|
||||
require 'DepthExpand2x'
|
||||
require 'WeightedMSECriterion'
|
||||
require 'PSNRCriterion'
|
||||
require 'ClippedWeightedHuberCriterion'
|
||||
require 'cleanup_model'
|
||||
return w2nn
|
||||
end
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
models/anime_style_art/noise3_model.json
Normal file
BIN
models/anime_style_art/noise3_model.json
Normal file
Binary file not shown.
BIN
models/anime_style_art/noise3_model.t7
Normal file
BIN
models/anime_style_art/noise3_model.t7
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
models/anime_style_art_rgb/noise3_model.json
Normal file
BIN
models/anime_style_art_rgb/noise3_model.json
Normal file
Binary file not shown.
BIN
models/anime_style_art_rgb/noise3_model.t7
Normal file
BIN
models/anime_style_art_rgb/noise3_model.t7
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
models/photo/noise3_model.json
Normal file
BIN
models/photo/noise3_model.json
Normal file
Binary file not shown.
BIN
models/photo/noise3_model.t7
Normal file
BIN
models/photo/noise3_model.t7
Normal file
Binary file not shown.
Binary file not shown.
|
@ -1,25 +0,0 @@
|
|||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
|
||||
|
||||
require 'w2nn'
|
||||
torch.setdefaulttensortype("torch.FloatTensor")
|
||||
|
||||
local cmd = torch.CmdLine()
|
||||
cmd:text()
|
||||
cmd:text("cleanup model")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-model", "./model.t7", 'path of model file')
|
||||
cmd:option("-iformat", "binary", 'input format')
|
||||
cmd:option("-oformat", "binary", 'output format')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
local model = torch.load(opt.model, opt.iformat)
|
||||
if model then
|
||||
w2nn.cleanup_model(model)
|
||||
model:cuda()
|
||||
model:evaluate()
|
||||
torch.save(opt.model, model, opt.oformat)
|
||||
else
|
||||
error("model not found")
|
||||
end
|
|
@ -24,11 +24,6 @@ function export(model, output)
|
|||
}
|
||||
table.insert(jmodules, jmod)
|
||||
end
|
||||
jmodules[1].color = "RGB"
|
||||
jmodules[1].gamma = 0
|
||||
jmodules[#jmodules].color = "RGB"
|
||||
jmodules[#jmodules].gamma = 0
|
||||
|
||||
local fp = io.open(output, "w")
|
||||
if not fp then
|
||||
error("IO Error: " .. output)
|
||||
|
|
41
train.lua
41
train.lua
|
@ -78,7 +78,11 @@ local function create_criterion(model)
|
|||
weight[3]:fill(0.11448 * 3) -- B
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
else
|
||||
return nn.MSECriterion():cuda()
|
||||
local offset = reconstruct.offset_size(model)
|
||||
local output_w = settings.crop_size - offset * 2
|
||||
local weight = torch.Tensor(1, output_w * output_w)
|
||||
weight[1]:fill(1.0)
|
||||
return w2nn.ClippedWeightedHuberCriterion(weight, 0.1, {0.0, 1.0}):cuda()
|
||||
end
|
||||
end
|
||||
local function transformer(x, is_validation, n, offset)
|
||||
|
@ -91,8 +95,8 @@ local function transformer(x, is_validation, n, offset)
|
|||
local active_cropping_rate = nil
|
||||
local active_cropping_tries = nil
|
||||
if is_validation then
|
||||
active_cropping_rate = 0
|
||||
active_cropping_tries = 0
|
||||
active_cropping_rate = settings.active_cropping_rate
|
||||
active_cropping_tries = settings.active_cropping_tries
|
||||
random_color_noise_rate = 0.0
|
||||
random_overlay_rate = 0.0
|
||||
else
|
||||
|
@ -108,6 +112,7 @@ local function transformer(x, is_validation, n, offset)
|
|||
settings.crop_size, offset,
|
||||
n,
|
||||
{
|
||||
downsampling_filters = settings.downsampling_filters,
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
|
@ -153,8 +158,14 @@ local function resampling(x, y, train_x, transformer, input_size, target_size)
|
|||
end
|
||||
end
|
||||
end
|
||||
|
||||
--- Plot the training and validation score histories as two line series.
-- @param train Lua array of numbers, plotted under the label 'training'
-- @param valid Lua array of numbers, plotted under the label 'validation'
local function plot(train, valid)
   local train_curve = {'training', torch.Tensor(train), '-'}
   local valid_curve = {'validation', torch.Tensor(valid), '-'}
   gnuplot.plot({train_curve, valid_curve})
end
|
||||
local function train()
|
||||
local hist_train = {}
|
||||
local hist_valid = {}
|
||||
local LR_MIN = 1.0e-5
|
||||
local model = srcnn.create(settings.method, settings.backend, settings.color)
|
||||
local offset = reconstruct.offset_size(model)
|
||||
|
@ -162,6 +173,7 @@ local function train()
|
|||
return transformer(x, is_validation, n, offset)
|
||||
end
|
||||
local criterion = create_criterion(model)
|
||||
local eval_metric = w2nn.PSNRCriterion():cuda()
|
||||
local x = torch.load(settings.images)
|
||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
||||
local adam_config = {
|
||||
|
@ -175,7 +187,7 @@ local function train()
|
|||
elseif settings.color == "rgb" then
|
||||
ch = 3
|
||||
end
|
||||
local best_score = 100000.0
|
||||
local best_score = 0.0
|
||||
print("# make validation-set")
|
||||
local valid_xy = make_validation_set(valid_x, pairwise_func,
|
||||
settings.validation_crops,
|
||||
|
@ -196,19 +208,24 @@ local function train()
|
|||
print("# " .. epoch)
|
||||
resampling(x, y, train_x, pairwise_func)
|
||||
for i = 1, settings.inner_epoch do
|
||||
print(minibatch_adam(model, criterion, x, y, adam_config))
|
||||
local train_score = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
||||
print(train_score)
|
||||
model:evaluate()
|
||||
print("# validation")
|
||||
local score = validate(model, criterion, valid_xy)
|
||||
if score < best_score then
|
||||
local score = validate(model, eval_metric, valid_xy)
|
||||
|
||||
table.insert(hist_train, train_score.PSNR)
|
||||
table.insert(hist_valid, score)
|
||||
if settings.plot then
|
||||
plot(hist_train, hist_valid)
|
||||
end
|
||||
if score > best_score then
|
||||
local test_image = image_loader.load_float(settings.test) -- reload
|
||||
lrd_count = 0
|
||||
best_score = score
|
||||
print("* update best model")
|
||||
if settings.save_history then
|
||||
local model_clone = model:clone()
|
||||
w2nn.cleanup_model(model_clone)
|
||||
torch.save(string.format(settings.model_file, epoch, i), model_clone)
|
||||
torch.save(string.format(settings.model_file, epoch, i), model:clearState(), "ascii")
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.%d-%d.png"):format(settings.noise_level,
|
||||
|
@ -221,7 +238,7 @@ local function train()
|
|||
save_test_scale(model, test_image, log)
|
||||
end
|
||||
else
|
||||
torch.save(settings.model_file, model)
|
||||
torch.save(settings.model_file, model:clearState(), "ascii")
|
||||
if settings.method == "noise" then
|
||||
local log = path.join(settings.model_dir,
|
||||
("noise%d_best.png"):format(settings.noise_level))
|
||||
|
|
5
train.sh
5
train.sh
|
@ -3,10 +3,5 @@
|
|||
th convert_data.lua
|
||||
|
||||
th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
||||
|
|
|
@ -2,11 +2,8 @@
|
|||
|
||||
th convert_data.lua -style photo -data_dir ./data/photo -model_dir models/photo
|
||||
|
||||
th train.lua -style photo -method scale -data_dir ./data/photo -model_dir models/photo_uk -test work/scale_test_photo.png -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.1 -validation_crops 160
|
||||
th tools/cleanup_model.lua -model models/photo/scale2.0x_model.t7 -oformat ascii
|
||||
th train.lua -style photo -method scale -data_dir ./data/photo -model_dir models/photo -test work/scale_test_photo.png -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.1 -validation_crops 160
|
||||
|
||||
th train.lua -style photo -method noise -noise_level 1 -data_dir ./data/photo -model_dir models/photo -test work/noise_test_photo.jpg -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.5 -validation_crops 160 -nr_rate 0.6 -epoch 33
|
||||
th tools/cleanup_model.lua -model models/photo/noise1_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -style photo -method noise -noise_level 2 -data_dir ./data/photo -model_dir models/photo -test work/noise_test_photo.jpg -color rgb -thread 4 -backend cudnn -random_unsharp_mask_rate 0.5 -validation_crops 160 -nr_rate 0.8 -epoch 38
|
||||
th tools/cleanup_model.lua -model models/photo/noise2_model.t7 -oformat ascii
|
||||
|
|
51
waifu2x.lua
51
waifu2x.lua
|
@ -69,7 +69,8 @@ local function convert_image(opt)
|
|||
print(opt.o .. ": " .. (sys.clock() - t) .. " sec")
|
||||
end
|
||||
local function convert_frames(opt)
|
||||
local model_path, noise1_model, noise2_model, scale_model
|
||||
local model_path, scale_model
|
||||
local noise_model = {}
|
||||
local scale_f, image_f
|
||||
if opt.tta == 1 then
|
||||
scale_f = reconstruct.scale_tta
|
||||
|
@ -84,16 +85,10 @@ local function convert_frames(opt)
|
|||
if not scale_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise" and opt.noise_level == 1 then
|
||||
model_path = path.join(opt.model_dir, "noise1_model.t7")
|
||||
noise1_model = torch.load(model_path, "ascii")
|
||||
if not noise1_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||
model_path = path.join(opt.model_dir, "noise2_model.t7")
|
||||
noise2_model = torch.load(model_path, "ascii")
|
||||
if not noise2_model then
|
||||
elseif opt.m == "noise" then
|
||||
model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
|
||||
noise_model[opt.noise_level] = torch.load(model_path, "ascii")
|
||||
if not noise_model[opt.noise_level] then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.m == "noise_scale" then
|
||||
|
@ -102,18 +97,10 @@ local function convert_frames(opt)
|
|||
if not scale_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
if opt.noise_level == 1 then
|
||||
model_path = path.join(opt.model_dir, "noise1_model.t7")
|
||||
noise1_model = torch.load(model_path, "ascii")
|
||||
if not noise1_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
elseif opt.noise_level == 2 then
|
||||
model_path = path.join(opt.model_dir, "noise2_model.t7")
|
||||
noise2_model = torch.load(model_path, "ascii")
|
||||
if not noise2_model then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
model_path = path.join(opt.model_dir, string.format("noise%d_model.t7", opt.noise_level))
|
||||
noise_model[opt.noise_level] = torch.load(model_path, "ascii")
|
||||
if not noise_model[opt.noise_level] then
|
||||
error("Load Error: " .. model_path)
|
||||
end
|
||||
end
|
||||
local fp = io.open(opt.l)
|
||||
|
@ -130,24 +117,16 @@ local function convert_frames(opt)
|
|||
if opt.resume == 0 or path.exists(string.format(opt.o, i)) == false then
|
||||
local x, alpha = image_loader.load_float(lines[i])
|
||||
local new_x = nil
|
||||
if opt.m == "noise" and opt.noise_level == 1 then
|
||||
new_x = image_f(noise1_model, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha)
|
||||
elseif opt.m == "noise" and opt.noise_level == 2 then
|
||||
new_x = image_f(noise2_model, x, opt.crop_size)
|
||||
if opt.m == "noise" then
|
||||
new_x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha)
|
||||
elseif opt.m == "scale" then
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 1 then
|
||||
elseif opt.m == "noise_scale" then
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
|
||||
x = image_f(noise1_model, x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
||||
elseif opt.m == "noise_scale" and opt.noise_level == 2 then
|
||||
x = alpha_util.make_border(x, alpha, reconstruct.offset_size(scale_model))
|
||||
x = image_f(noise2_model, x, opt.crop_size)
|
||||
x = image_f(noise_model[opt.noise_level], x, opt.crop_size)
|
||||
new_x = scale_f(scale_model, opt.scale, x, opt.crop_size)
|
||||
new_x = alpha_util.composite(new_x, alpha, scale_model)
|
||||
else
|
||||
|
@ -185,7 +164,7 @@ local function waifu2x()
|
|||
cmd:option("-depth", 8, 'bit-depth of the output image (8|16)')
|
||||
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
|
||||
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
|
||||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-noise_level", 1, '(1|2|3)')
|
||||
cmd:option("-crop_size", 128, 'patch size per process')
|
||||
cmd:option("-resume", 0, "skip existing files (0|1)")
|
||||
cmd:option("-thread", -1, "number of CPU threads")
|
||||
|
|
24
web.lua
24
web.lua
|
@ -38,12 +38,14 @@ if cudnn then
|
|||
end
|
||||
local ART_MODEL_DIR = path.join(ROOT, "models", "anime_style_art_rgb")
|
||||
local PHOTO_MODEL_DIR = path.join(ROOT, "models", "photo")
|
||||
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local art_noise1_model = torch.load(path.join(ART_MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
local art_noise2_model = torch.load(path.join(ART_MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local art_scale2_model = torch.load(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local art_noise3_model = torch.load(path.join(ART_MODEL_DIR, "noise3_model.t7"), "ascii")
|
||||
local photo_scale2_model = torch.load(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), "ascii")
|
||||
local photo_noise1_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), "ascii")
|
||||
local photo_noise2_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), "ascii")
|
||||
local photo_noise3_model = torch.load(path.join(PHOTO_MODEL_DIR, "noise3_model.t7"), "ascii")
|
||||
local CLEANUP_MODEL = false -- if you are using the low memory GPU, you could use this flag.
|
||||
local CACHE_DIR = path.join(ROOT, "cache")
|
||||
local MAX_NOISE_IMAGE = 2560 * 2560
|
||||
|
@ -114,7 +116,7 @@ local function get_image(req)
|
|||
end
|
||||
local function cleanup_model(model)
|
||||
if CLEANUP_MODEL then
|
||||
w2nn.cleanup_model(model) -- release GPU memory
|
||||
model:clearState() -- release GPU memory
|
||||
end
|
||||
end
|
||||
local function convert(x, alpha, options)
|
||||
|
@ -151,9 +153,12 @@ local function convert(x, alpha, options)
|
|||
elseif options.method == "noise1" then
|
||||
x = reconstruct.image(art_noise1_model, x)
|
||||
cleanup_model(art_noise1_model)
|
||||
else -- options.method == "noise2"
|
||||
elseif options.method == "noise2" then
|
||||
x = reconstruct.image(art_noise2_model, x)
|
||||
cleanup_model(art_noise2_model)
|
||||
elseif options.method == "noise3" then
|
||||
x = reconstruct.image(art_noise3_model, x)
|
||||
cleanup_model(art_noise3_model)
|
||||
end
|
||||
else -- photo
|
||||
if options.border then
|
||||
|
@ -174,6 +179,9 @@ local function convert(x, alpha, options)
|
|||
elseif options.method == "noise2" then
|
||||
x = reconstruct.image(photo_noise2_model, x)
|
||||
cleanup_model(photo_noise2_model)
|
||||
elseif options.method == "noise3" then
|
||||
x = reconstruct.image(photo_noise3_model, x)
|
||||
cleanup_model(photo_noise3_model)
|
||||
end
|
||||
end
|
||||
image_loader.save_png(cache_file, x)
|
||||
|
@ -229,17 +237,25 @@ function APIHandler:post()
|
|||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
elseif noise == 2 then
|
||||
prefix = style .. "_noise1_"
|
||||
prefix = style .. "_noise2_"
|
||||
x = convert(x, alpha, {method = "noise2", style = style,
|
||||
prefix = prefix .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
elseif noise == 3 then
|
||||
prefix = style .. "_noise3_"
|
||||
x = convert(x, alpha, {method = "noise3", style = style,
|
||||
prefix = prefix .. hash,
|
||||
alpha_prefix = alpha_prefix, border = border})
|
||||
border = false
|
||||
end
|
||||
if scale == 1 or scale == 2 then
|
||||
if noise == 1 then
|
||||
prefix = style .. "_noise1_scale_"
|
||||
elseif noise == 2 then
|
||||
prefix = style .. "_noise2_scale_"
|
||||
elseif noise == 3 then
|
||||
prefix = style .. "_noise3_scale_"
|
||||
else
|
||||
prefix = style .. "_scale_"
|
||||
end
|
||||
|
|
|
@ -14,6 +14,7 @@ expect_jpeg: expect JPEG artifact
|
|||
nr_none: None
|
||||
nr_medium: Medium
|
||||
nr_high: High
|
||||
nr_highest: Highest
|
||||
nr_hint: "You need use noise reduction if image actually has noise or it may cause opposite effect."
|
||||
upscaling: Upscaling
|
||||
up_none: None
|
||||
|
|
|
@ -14,6 +14,7 @@ expect_jpeg: JPEGノイズを想定
|
|||
nr_none: なし
|
||||
nr_medium: 中
|
||||
nr_high: 高
|
||||
nr_highest: 最高
|
||||
nr_hint: "ノイズ除去は細部が消えることがあります。JPEGノイズがある場合に使用します。"
|
||||
upscaling: 拡大
|
||||
up_none: なし
|
||||
|
|
|
@ -106,6 +106,12 @@
|
|||
<%= t[:nr_high] %>
|
||||
</span>
|
||||
</label>
|
||||
<label>
|
||||
<input type="radio" name="noise" class="radio" value="3">
|
||||
<span class="r-text">
|
||||
<%= t[:nr_highest] %>
|
||||
</span>
|
||||
</label>
|
||||
</div>
|
||||
<div class="option-hint">
|
||||
<%= t[:nr_hint] %>
|
||||
|
|
Loading…
Reference in a new issue