cleanup
This commit is contained in:
parent
2b9753157a
commit
903d945652
|
@ -152,6 +152,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 -
|
|||
```
|
||||
|
||||
## Training Your Own Model
|
||||
Notes: If you have cuDNN library, you can use cudnn kernel with `-backend cudnn` option. And you can convert trained cudnn model to cunn model with `tools/cudnn2cunn.lua`.
|
||||
|
||||
### Data Preparation
|
||||
|
||||
|
|
|
@ -7,8 +7,7 @@ local pairwise_transform = {}
|
|||
|
||||
local function random_half(src, p)
|
||||
p = p or 0.25
|
||||
--local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
|
||||
local filter = "Box"
|
||||
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
|
||||
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
|
||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||
else
|
||||
|
@ -163,8 +162,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
|||
end
|
||||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
||||
if category == "anime_style_art" then
|
||||
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
|
||||
if style == "art" then
|
||||
if level == 1 then
|
||||
if torch.uniform() > 0.8 then
|
||||
return pairwise_transform.jpeg_(src, {},
|
||||
|
@ -200,7 +199,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
|||
else
|
||||
error("unknown noise level: " .. level)
|
||||
end
|
||||
elseif category == "photo" then
|
||||
elseif style == "photo" then
|
||||
if level == 1 then
|
||||
if torch.uniform() > 0.7 then
|
||||
return pairwise_transform.jpeg_(src, {},
|
||||
|
@ -225,7 +224,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
|||
error("unknown noise level: " .. level)
|
||||
end
|
||||
else
|
||||
error("unknown category: " .. category)
|
||||
error("unknown style: " .. style)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -239,7 +238,7 @@ function pairwise_transform.test_jpeg(src)
|
|||
}
|
||||
for i = 1, 9 do
|
||||
local xy = pairwise_transform.jpeg(src,
|
||||
"anime_style_art",
|
||||
"art",
|
||||
torch.random(1, 2),
|
||||
128, 7, 1, options)
|
||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
|
||||
|
|
|
@ -17,30 +17,30 @@ local cmd = torch.CmdLine()
|
|||
cmd:text()
|
||||
cmd:text("waifu2x-training")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-seed", 11, 'fixed input seed')
|
||||
cmd:option("-data_dir", "./data", 'data directory')
|
||||
cmd:option("-seed", 11, 'RNG seed')
|
||||
cmd:option("-data_dir", "./data", 'path to data directory')
|
||||
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
||||
cmd:option("-test", "images/miku_small.png", 'test image file')
|
||||
cmd:option("-test", "images/miku_small.png", 'path to test image')
|
||||
cmd:option("-model_dir", "./models", 'model directory')
|
||||
cmd:option("-method", "scale", '(noise|scale)')
|
||||
cmd:option("-method", "scale", 'method to training (noise|scale)')
|
||||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
|
||||
cmd:option("-style", "art", '(art|photo)')
|
||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||
cmd:option("-color_noise", 0, 'enable data augmentation using color noise (1|0)')
|
||||
cmd:option("-overlay", 0, 'enable data augmentation using overlay (1|0)')
|
||||
cmd:option("-scale", 2.0, 'scale')
|
||||
cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)')
|
||||
cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)')
|
||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
|
||||
cmd:option("-crop_size", 128, 'crop size')
|
||||
cmd:option("-max_size", 512, 'crop if image size larger then this value.')
|
||||
cmd:option("-batch_size", 2, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'epoch')
|
||||
cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
|
||||
cmd:option("-crop_size", 46, 'crop size')
|
||||
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
|
||||
cmd:option("-batch_size", 8, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'number of total epochs to run')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
|
||||
cmd:option("-validation_ratio", 0.1, 'validation ratio')
|
||||
cmd:option("-validation_crops", 40, 'number of crop region in validation')
|
||||
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
|
||||
cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
|
||||
cmd:option("-validation_crops", 80, 'number of region per image in validation')
|
||||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||
cmd:option("-active_cropping_tries", 20, 'active cropping tries')
|
||||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||
|
||||
local opt = cmd:parse(arg)
|
||||
for k, v in pairs(opt) do
|
||||
|
@ -64,9 +64,9 @@ end
|
|||
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
|
||||
error("scale must be mod-2")
|
||||
end
|
||||
if not (settings.category == "anime_style_art" or
|
||||
settings.category == "photo") then
|
||||
error(string.format("unknown category: %s", settings.category))
|
||||
if not (settings.style == "art" or
|
||||
settings.style == "photo") then
|
||||
error(string.format("unknown style: %s", settings.style))
|
||||
end
|
||||
if settings.random_half == 1 then
|
||||
settings.random_half = true
|
||||
|
|
|
@ -87,7 +87,7 @@ local function transformer(x, is_validation, n, offset)
|
|||
if is_validation == nil then is_validation = false end
|
||||
local color_noise = nil
|
||||
local overlay = nil
|
||||
local active_cropping_ratio = nil
|
||||
local active_cropping_rate = nil
|
||||
local active_cropping_tries = nil
|
||||
|
||||
if is_validation then
|
||||
|
@ -117,7 +117,7 @@ local function transformer(x, is_validation, n, offset)
|
|||
})
|
||||
elseif settings.method == "noise" then
|
||||
return pairwise_transform.jpeg(x,
|
||||
settings.category,
|
||||
settings.style,
|
||||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
|
@ -142,7 +142,7 @@ local function train()
|
|||
local criterion = create_criterion(model)
|
||||
local x = torch.load(settings.images)
|
||||
local lrd_count = 0
|
||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x))
|
||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
||||
local adam_config = {
|
||||
learningRate = settings.learning_rate,
|
||||
xBatchSize = settings.batch_size,
|
||||
|
|
13
train.sh
13
train.sh
|
@ -2,12 +2,11 @@
|
|||
|
||||
th convert_data.lua
|
||||
|
||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 1 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
||||
th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 2 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method scale -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_small_noisy.jpg -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_ratio 0.1 -validation_crops 80
|
||||
th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
||||
th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
|
||||
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
||||
|
|
|
@ -1,11 +1,6 @@
|
|||
#!/bin/sh
|
||||
|
||||
th train.lua -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 1 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
|
||||
th cleanup_model.lua -model models/ukbench2/noise1_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -core 1 -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 2 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
|
||||
th cleanup_model.lua -model models/ukbench2/noise2_model.t7 -oformat ascii
|
||||
|
||||
th train.lua -category photo -color rgb -random_half 0 -epoch 400 -batch_size 1 -method scale -scale 2 -model_dir models/ukbench2 -data_dir ukbench -test photo2-noise.png
|
||||
th cleanup_model.lua -model models/ukbench2/scale2.0x_model.t7 -oformat ascii
|
||||
th convert_data.lua -data_dir ./data/ukbench
|
||||
|
||||
th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg -thread 4
|
||||
th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii
|
||||
|
|
|
@ -128,11 +128,11 @@ local function waifu2x()
|
|||
cmd:text()
|
||||
cmd:text("waifu2x")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-i", "images/miku_small.png", 'path of the input image')
|
||||
cmd:option("-l", "", 'path of the image-list')
|
||||
cmd:option("-i", "images/miku_small.png", 'path to input image')
|
||||
cmd:option("-l", "", 'path to image-list.txt')
|
||||
cmd:option("-scale", 2, 'scale factor')
|
||||
cmd:option("-o", "(auto)", 'path of the output file')
|
||||
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
|
||||
cmd:option("-o", "(auto)", 'path to output file')
|
||||
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
|
||||
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
|
||||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-crop_size", 128, 'patch size per process')
|
||||
|
|
Loading…
Reference in a new issue