cleanup
This commit is contained in:
parent
2b9753157a
commit
903d945652
|
@ -152,6 +152,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 -
|
||||||
```
|
```
|
||||||
|
|
||||||
## Training Your Own Model
|
## Training Your Own Model
|
||||||
|
Notes: If you have cuDNN library, you can use cudnn kernel with `-backend cudnn` option. And you can convert trained cudnn model to cunn model with `tools/cudnn2cunn.lua`.
|
||||||
|
|
||||||
### Data Preparation
|
### Data Preparation
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,7 @@ local pairwise_transform = {}
|
||||||
|
|
||||||
local function random_half(src, p)
|
local function random_half(src, p)
|
||||||
p = p or 0.25
|
p = p or 0.25
|
||||||
--local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)]
|
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
|
||||||
local filter = "Box"
|
|
||||||
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
|
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
|
||||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||||
else
|
else
|
||||||
|
@ -163,8 +162,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
||||||
end
|
end
|
||||||
return batch
|
return batch
|
||||||
end
|
end
|
||||||
function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
function pairwise_transform.jpeg(src, style, level, size, offset, n, options)
|
||||||
if category == "anime_style_art" then
|
if style == "art" then
|
||||||
if level == 1 then
|
if level == 1 then
|
||||||
if torch.uniform() > 0.8 then
|
if torch.uniform() > 0.8 then
|
||||||
return pairwise_transform.jpeg_(src, {},
|
return pairwise_transform.jpeg_(src, {},
|
||||||
|
@ -200,7 +199,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
||||||
else
|
else
|
||||||
error("unknown noise level: " .. level)
|
error("unknown noise level: " .. level)
|
||||||
end
|
end
|
||||||
elseif category == "photo" then
|
elseif style == "photo" then
|
||||||
if level == 1 then
|
if level == 1 then
|
||||||
if torch.uniform() > 0.7 then
|
if torch.uniform() > 0.7 then
|
||||||
return pairwise_transform.jpeg_(src, {},
|
return pairwise_transform.jpeg_(src, {},
|
||||||
|
@ -225,7 +224,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
||||||
error("unknown noise level: " .. level)
|
error("unknown noise level: " .. level)
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
error("unknown category: " .. category)
|
error("unknown style: " .. style)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -239,7 +238,7 @@ function pairwise_transform.test_jpeg(src)
|
||||||
}
|
}
|
||||||
for i = 1, 9 do
|
for i = 1, 9 do
|
||||||
local xy = pairwise_transform.jpeg(src,
|
local xy = pairwise_transform.jpeg(src,
|
||||||
"anime_style_art",
|
"art",
|
||||||
torch.random(1, 2),
|
torch.random(1, 2),
|
||||||
128, 7, 1, options)
|
128, 7, 1, options)
|
||||||
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
|
image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1})
|
||||||
|
|
|
@ -17,30 +17,30 @@ local cmd = torch.CmdLine()
|
||||||
cmd:text()
|
cmd:text()
|
||||||
cmd:text("waifu2x-training")
|
cmd:text("waifu2x-training")
|
||||||
cmd:text("Options:")
|
cmd:text("Options:")
|
||||||
cmd:option("-seed", 11, 'fixed input seed')
|
cmd:option("-seed", 11, 'RNG seed')
|
||||||
cmd:option("-data_dir", "./data", 'data directory')
|
cmd:option("-data_dir", "./data", 'path to data directory')
|
||||||
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
||||||
cmd:option("-test", "images/miku_small.png", 'test image file')
|
cmd:option("-test", "images/miku_small.png", 'path to test image')
|
||||||
cmd:option("-model_dir", "./models", 'model directory')
|
cmd:option("-model_dir", "./models", 'model directory')
|
||||||
cmd:option("-method", "scale", '(noise|scale)')
|
cmd:option("-method", "scale", 'method to training (noise|scale)')
|
||||||
cmd:option("-noise_level", 1, '(1|2)')
|
cmd:option("-noise_level", 1, '(1|2)')
|
||||||
cmd:option("-category", "anime_style_art", '(anime_style_art|photo)')
|
cmd:option("-style", "art", '(art|photo)')
|
||||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||||
cmd:option("-color_noise", 0, 'enable data augmentation using color noise (1|0)')
|
cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)')
|
||||||
cmd:option("-overlay", 0, 'enable data augmentation using overlay (1|0)')
|
cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)')
|
||||||
cmd:option("-scale", 2.0, 'scale')
|
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||||
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
|
cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
|
||||||
cmd:option("-crop_size", 128, 'crop size')
|
cmd:option("-crop_size", 46, 'crop size')
|
||||||
cmd:option("-max_size", 512, 'crop if image size larger then this value.')
|
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
|
||||||
cmd:option("-batch_size", 2, 'mini batch size')
|
cmd:option("-batch_size", 8, 'mini batch size')
|
||||||
cmd:option("-epoch", 200, 'epoch')
|
cmd:option("-epoch", 200, 'number of total epochs to run')
|
||||||
cmd:option("-thread", -1, 'number of CPU threads')
|
cmd:option("-thread", -1, 'number of CPU threads')
|
||||||
cmd:option("-jpeg_sampling_factors", 444, '(444|422)')
|
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
|
||||||
cmd:option("-validation_ratio", 0.1, 'validation ratio')
|
cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
|
||||||
cmd:option("-validation_crops", 40, 'number of crop region in validation')
|
cmd:option("-validation_crops", 80, 'number of region per image in validation')
|
||||||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||||
cmd:option("-active_cropping_tries", 20, 'active cropping tries')
|
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||||
|
|
||||||
local opt = cmd:parse(arg)
|
local opt = cmd:parse(arg)
|
||||||
for k, v in pairs(opt) do
|
for k, v in pairs(opt) do
|
||||||
|
@ -64,9 +64,9 @@ end
|
||||||
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
|
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
|
||||||
error("scale must be mod-2")
|
error("scale must be mod-2")
|
||||||
end
|
end
|
||||||
if not (settings.category == "anime_style_art" or
|
if not (settings.style == "art" or
|
||||||
settings.category == "photo") then
|
settings.style == "photo") then
|
||||||
error(string.format("unknown category: %s", settings.category))
|
error(string.format("unknown style: %s", settings.style))
|
||||||
end
|
end
|
||||||
if settings.random_half == 1 then
|
if settings.random_half == 1 then
|
||||||
settings.random_half = true
|
settings.random_half = true
|
||||||
|
|
|
@ -87,7 +87,7 @@ local function transformer(x, is_validation, n, offset)
|
||||||
if is_validation == nil then is_validation = false end
|
if is_validation == nil then is_validation = false end
|
||||||
local color_noise = nil
|
local color_noise = nil
|
||||||
local overlay = nil
|
local overlay = nil
|
||||||
local active_cropping_ratio = nil
|
local active_cropping_rate = nil
|
||||||
local active_cropping_tries = nil
|
local active_cropping_tries = nil
|
||||||
|
|
||||||
if is_validation then
|
if is_validation then
|
||||||
|
@ -117,7 +117,7 @@ local function transformer(x, is_validation, n, offset)
|
||||||
})
|
})
|
||||||
elseif settings.method == "noise" then
|
elseif settings.method == "noise" then
|
||||||
return pairwise_transform.jpeg(x,
|
return pairwise_transform.jpeg(x,
|
||||||
settings.category,
|
settings.style,
|
||||||
settings.noise_level,
|
settings.noise_level,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
n,
|
n,
|
||||||
|
@ -142,7 +142,7 @@ local function train()
|
||||||
local criterion = create_criterion(model)
|
local criterion = create_criterion(model)
|
||||||
local x = torch.load(settings.images)
|
local x = torch.load(settings.images)
|
||||||
local lrd_count = 0
|
local lrd_count = 0
|
||||||
local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x))
|
local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x))
|
||||||
local adam_config = {
|
local adam_config = {
|
||||||
learningRate = settings.learning_rate,
|
learningRate = settings.learning_rate,
|
||||||
xBatchSize = settings.batch_size,
|
xBatchSize = settings.batch_size,
|
||||||
|
|
13
train.sh
13
train.sh
|
@ -2,12 +2,11 @@
|
||||||
|
|
||||||
th convert_data.lua
|
th convert_data.lua
|
||||||
|
|
||||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 1 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
|
th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4
|
||||||
th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
||||||
|
|
||||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 2 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80
|
th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
|
||||||
th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii
|
||||||
|
|
||||||
th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method scale -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_small_noisy.jpg -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_ratio 0.1 -validation_crops 80
|
|
||||||
th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii
|
|
||||||
|
|
||||||
|
th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4
|
||||||
|
th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii
|
||||||
|
|
|
@ -1,11 +1,6 @@
|
||||||
#!/bin/sh
|
#!/bin/sh
|
||||||
|
|
||||||
th train.lua -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 1 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
|
th convert_data.lua -data_dir ./data/ukbench
|
||||||
th cleanup_model.lua -model models/ukbench2/noise1_model.t7 -oformat ascii
|
|
||||||
|
|
||||||
th train.lua -core 1 -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 2 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg
|
|
||||||
th cleanup_model.lua -model models/ukbench2/noise2_model.t7 -oformat ascii
|
|
||||||
|
|
||||||
th train.lua -category photo -color rgb -random_half 0 -epoch 400 -batch_size 1 -method scale -scale 2 -model_dir models/ukbench2 -data_dir ukbench -test photo2-noise.png
|
|
||||||
th cleanup_model.lua -model models/ukbench2/scale2.0x_model.t7 -oformat ascii
|
|
||||||
|
|
||||||
|
th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg -thread 4
|
||||||
|
th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii
|
||||||
|
|
|
@ -128,11 +128,11 @@ local function waifu2x()
|
||||||
cmd:text()
|
cmd:text()
|
||||||
cmd:text("waifu2x")
|
cmd:text("waifu2x")
|
||||||
cmd:text("Options:")
|
cmd:text("Options:")
|
||||||
cmd:option("-i", "images/miku_small.png", 'path of the input image')
|
cmd:option("-i", "images/miku_small.png", 'path to input image')
|
||||||
cmd:option("-l", "", 'path of the image-list')
|
cmd:option("-l", "", 'path to image-list.txt')
|
||||||
cmd:option("-scale", 2, 'scale factor')
|
cmd:option("-scale", 2, 'scale factor')
|
||||||
cmd:option("-o", "(auto)", 'path of the output file')
|
cmd:option("-o", "(auto)", 'path to output file')
|
||||||
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory')
|
cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory')
|
||||||
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
|
cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)')
|
||||||
cmd:option("-noise_level", 1, '(1|2)')
|
cmd:option("-noise_level", 1, '(1|2)')
|
||||||
cmd:option("-crop_size", 128, 'patch size per process')
|
cmd:option("-crop_size", 128, 'patch size per process')
|
||||||
|
|
Loading…
Reference in a new issue