1
0
Fork 0
Mirror of upstream repository, synced 2024-05-19 20:32:22 +12:00
This commit is contained in:
nagadomi 2015-11-06 10:08:54 +09:00
parent 2b9753157a
commit 903d945652
7 changed files with 43 additions and 49 deletions

View file

@ -152,6 +152,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 -
```
## Training Your Own Model
Notes: If you have cuDNN library, you can use cudnn kernel with `-backend cudnn` option. And you can convert trained cudnn model to cunn model with `tools/cudnn2cunn.lua`.
### Data Preparation

View file

@ -7,8 +7,7 @@ local pairwise_transform = {}
local function random_half(src, p)
p = p or 0.25
--local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
local filter = "Box"
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
else
@ -163,8 +162,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
end
return batch
end
function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
if category == "anime_style_art" then
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
if style == "art" then
if level == 1 then
if torch.uniform() > 0.8 then
return pairwise_transform.jpeg_(src, {},
@ -200,7 +199,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
else
error("unknown noise level: " .. level)
end
elseif category == "photo" then
elseif style == "photo" then
if level == 1 then
if torch.uniform() > 0.7 then
return pairwise_transform.jpeg_(src, {},
@ -225,7 +224,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
error("unknown noise level: " .. level)
end
else
error("unknown category: " .. category)
error("unknown style: " .. style)
end
end
@ -239,7 +238,7 @@ function pairwise_transform.test_jpeg(src)
}
for i = 1, 9 do
local xy = pairwise_transform.jpeg(src,
"anime_style_art",
"art",
torch.random(1, 2),
128, 7, 1, options)
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})

View file

@ -17,30 +17,30 @@ local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x-training")
cmd:text("Options:")
cmd:option("-seed", 11, 'fixed input seed')
cmd:option("-data_dir", "./data", 'data directory')
cmd:option("-seed", 11, 'RNG seed')
cmd:option("-data_dir", "./data", 'path to data directory')
cmd:option("-backend", "cunn", '(cunn|cudnn)')
cmd:option("-test", "images/miku_small.png", 'test image file')
cmd:option("-test", "images/miku_small.png", 'path to test image')
cmd:option("-model_dir", "./models", 'model directory')
cmd:option("-method", "scale", '(noise|scale)')
cmd:option("-method", "scale", 'method to training (noise|scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
cmd:option("-style", "art", '(art|photo)')
cmd:option("-color", 'rgb', '(y|rgb)')
cmd:option("-color_noise", 0, 'enable data augmentation using color noise (1|0)')
cmd:option("-overlay", 0, 'enable data augmentation using overlay (1|0)')
cmd:option("-scale", 2.0, 'scale')
cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)')
cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)')
cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
cmd:option("-crop_size", 128, 'crop size')
cmd:option("-max_size", 512, 'crop if image size larger then this value.')
cmd:option("-batch_size", 2, 'mini batch size')
cmd:option("-epoch", 200, 'epoch')
cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
cmd:option("-crop_size", 46, 'crop size')
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
cmd:option("-batch_size", 8, 'mini batch size')
cmd:option("-epoch", 200, 'number of total epochs to run')
cmd:option("-thread", -1, 'number of CPU threads')
cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
cmd:option("-validation_ratio", 0.1, 'validation ratio')
cmd:option("-validation_crops", 40, 'number of crop region in validation')
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
cmd:option("-validation_crops", 80, 'number of region per image in validation')
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
cmd:option("-active_cropping_tries", 20, 'active cropping tries')
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
local opt = cmd:parse(arg)
for k, v in pairs(opt) do
@ -64,9 +64,9 @@ end
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
error("scale must be mod-2")
end
if not (settings.category == "anime_style_art" or
settings.category == "photo") then
error(string.format("unknown category: %s", settings.category))
if not (settings.style == "art" or
settings.style == "photo") then
error(string.format("unknown style: %s", settings.style))
end
if settings.random_half == 1 then
settings.random_half = true

View file

@ -87,7 +87,7 @@ local function transformer(x, is_validation, n, offset)
if is_validation == nil then is_validation = false end
local color_noise = nil
local overlay = nil
local active_cropping_ratio = nil
local active_cropping_rate = nil
local active_cropping_tries = nil
if is_validation then
@ -117,7 +117,7 @@ local function transformer(x, is_validation, n, offset)
})
elseif settings.method == "noise" then
return pairwise_transform.jpeg(x,
settings.category,
settings.style,
settings.noise_level,
settings.crop_size, offset,
n,
@ -142,7 +142,7 @@ local function train()
local criterion = create_criterion(model)
local x = torch.load(settings.images)
local lrd_count = 0
local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x))
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
local adam_config = {
learningRate = settings.learning_rate,
xBatchSize = settings.batch_size,

View file

@ -2,12 +2,11 @@
th convert_data.lua
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 1 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 2 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method scale -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_small_noisy.jpg -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_ratio 0.1 -validation_crops 80
th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii

View file

@ -1,11 +1,6 @@
#!/bin/sh
th train.lua -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 1 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
th cleanup_model.lua -model models/ukbench2/noise1_model.t7 -oformat ascii
th train.lua -core 1 -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 2 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
th cleanup_model.lua -model models/ukbench2/noise2_model.t7 -oformat ascii
th train.lua -category photo -color rgb -random_half 0 -epoch 400 -batch_size 1 -method scale -scale 2 -model_dir models/ukbench2 -data_dir ukbench -test photo2-noise.png
th cleanup_model.lua -model models/ukbench2/scale2.0x_model.t7 -oformat ascii
th convert_data.lua -data_dir ./data/ukbench
th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg -thread 4
th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii

View file

@ -128,11 +128,11 @@ local function waifu2x()
cmd:text()
cmd:text("waifu2x")
cmd:text("Options:")
cmd:option("-i", "images/miku_small.png", 'path of the input image')
cmd:option("-l", "", 'path of the image-list')
cmd:option("-i", "images/miku_small.png", 'path to input image')
cmd:option("-l", "", 'path to image-list.txt')
cmd:option("-scale", 2, 'scale factor')
cmd:option("-o", "(auto)", 'path of the output file')
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
cmd:option("-o", "(auto)", 'path to output file')
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-crop_size", 128, 'patch size per process')