diff --git a/README.md b/README.md index fb86504..6acf65b 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,7 @@ avconv -f image2 -r 24 -i new_frames/%d.png -i audio.mp3 -r 24 -vcodec libx264 - ``` ## Training Your Own Model +Note: If you have the cuDNN library, you can use the cuDNN kernel with the `-backend cudnn` option. You can also convert a trained cudnn model to a cunn model with `tools/cudnn2cunn.lua`. ### Data Preparation diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua index 52bb3be..d4efbc5 100644 --- a/lib/pairwise_transform.lua +++ b/lib/pairwise_transform.lua @@ -7,8 +7,7 @@ local pairwise_transform = {} local function random_half(src, p) p = p or 0.25 - --local filter = ({"Box","Blackman", "SincFast", "Jinc"})[torch.random(1, 4)] - local filter = "Box" + local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)] if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter) else @@ -163,8 +162,8 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options) end return batch end -function pairwise_transform.jpeg(src, category, level, size, offset, n, options) - if category == "anime_style_art" then +function pairwise_transform.jpeg(src, style, level, size, offset, n, options) + if style == "art" then if level == 1 then if torch.uniform() > 0.8 then return pairwise_transform.jpeg_(src, {}, @@ -200,7 +199,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options) else error("unknown noise level: " .. level) end - elseif category == "photo" then + elseif style == "photo" then if level == 1 then if torch.uniform() > 0.7 then return pairwise_transform.jpeg_(src, {}, @@ -225,7 +224,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options) error("unknown noise level: " .. level) end else - error("unknown category: " .. category) + error("unknown style: " .. 
style) end end @@ -239,7 +238,7 @@ function pairwise_transform.test_jpeg(src) } for i = 1, 9 do local xy = pairwise_transform.jpeg(src, - "anime_style_art", + "art", torch.random(1, 2), 128, 7, 1, options) image.display({image = xy[1][1], legend = "y:" .. (i * 10), min=0, max=1}) diff --git a/lib/settings.lua b/lib/settings.lua index 3ce097c..ab94078 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -17,30 +17,30 @@ local cmd = torch.CmdLine() cmd:text() cmd:text("waifu2x-training") cmd:text("Options:") -cmd:option("-seed", 11, 'fixed input seed') -cmd:option("-data_dir", "./data", 'data directory') +cmd:option("-seed", 11, 'RNG seed') +cmd:option("-data_dir", "./data", 'path to data directory') cmd:option("-backend", "cunn", '(cunn|cudnn)') -cmd:option("-test", "images/miku_small.png", 'test image file') +cmd:option("-test", "images/miku_small.png", 'path to test image') cmd:option("-model_dir", "./models", 'model directory') -cmd:option("-method", "scale", '(noise|scale)') +cmd:option("-method", "scale", 'method to training (noise|scale)') cmd:option("-noise_level", 1, '(1|2)') -cmd:option("-category", "anime_style_art", '(anime_style_art|photo)') +cmd:option("-style", "art", '(art|photo)') cmd:option("-color", 'rgb', '(y|rgb)') -cmd:option("-color_noise", 0, 'enable data augmentation using color noise (1|0)') -cmd:option("-overlay", 0, 'enable data augmentation using overlay (1|0)') -cmd:option("-scale", 2.0, 'scale') +cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)') +cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)') +cmd:option("-scale", 2.0, 'scale factor (2)') cmd:option("-learning_rate", 0.00025, 'learning rate for adam') -cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)') -cmd:option("-crop_size", 128, 'crop size') -cmd:option("-max_size", 512, 'crop if image size larger then this value.') -cmd:option("-batch_size", 2, 'mini batch size') -cmd:option("-epoch", 200, 
'epoch') +cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)') +cmd:option("-crop_size", 46, 'crop size') +cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly') +cmd:option("-batch_size", 8, 'mini batch size') +cmd:option("-epoch", 200, 'number of total epochs to run') cmd:option("-thread", -1, 'number of CPU threads') -cmd:option("-jpeg_sampling_factors", 444, '(444|422)') -cmd:option("-validation_ratio", 0.1, 'validation ratio') -cmd:option("-validation_crops", 40, 'number of crop region in validation') +cmd:option("-jpeg_sampling_factors", 444, '(444|420)') +cmd:option("-validation_rate", 0.05, 'validation-set rate of data') +cmd:option("-validation_crops", 80, 'number of region per image in validation') cmd:option("-active_cropping_rate", 0.5, 'active cropping rate') -cmd:option("-active_cropping_tries", 20, 'active cropping tries') +cmd:option("-active_cropping_tries", 10, 'active cropping tries') local opt = cmd:parse(arg) for k, v in pairs(opt) do @@ -64,9 +64,9 @@ end if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then error("scale must be mod-2") end -if not (settings.category == "anime_style_art" or - settings.category == "photo") then - error(string.format("unknown category: %s", settings.category)) +if not (settings.style == "art" or + settings.style == "photo") then + error(string.format("unknown style: %s", settings.style)) end if settings.random_half == 1 then settings.random_half = true diff --git a/train.lua b/train.lua index b315c6e..e614da4 100644 --- a/train.lua +++ b/train.lua @@ -87,7 +87,7 @@ local function transformer(x, is_validation, n, offset) if is_validation == nil then is_validation = false end local color_noise = nil local overlay = nil - local active_cropping_ratio = nil + local active_cropping_rate = nil local active_cropping_tries = nil if is_validation then @@ -117,7 +117,7 @@ local function transformer(x, 
is_validation, n, offset) }) elseif settings.method == "noise" then return pairwise_transform.jpeg(x, - settings.category, + settings.style, settings.noise_level, settings.crop_size, offset, n, @@ -142,7 +142,7 @@ local function train() local criterion = create_criterion(model) local x = torch.load(settings.images) local lrd_count = 0 - local train_x, valid_x = split_data(x, math.floor(settings.validation_ratio * #x)) + local train_x, valid_x = split_data(x, math.floor(settings.validation_rate * #x)) local adam_config = { learningRate = settings.learning_rate, xBatchSize = settings.batch_size, diff --git a/train.sh b/train.sh index 22fdc09..79804ea 100755 --- a/train.sh +++ b/train.sh @@ -2,12 +2,11 @@ th convert_data.lua -th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 1 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80 -th cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii +th train.lua -method scale -model_dir models/anime_style_art_rgb -test images/miku_small.png -thread 4 +th tools/cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii -th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method noise -noise_level 2 -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_noisy.jpg -validation_ratio 0.1 -active_cropping_rate 0.5 -active_cropping_tries 10 -validation_crops 80 -th cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii - -th train.lua -color rgb -random_half 1 -jpeg_sampling_factors 444 -color_noise 0 -overlay 0 -epoch 200 -method scale -crop_size 46 -batch_size 8 -model_dir models/anime_style_art_rgb -test images/miku_small_noisy.jpg -active_cropping_rate 0.5 
-active_cropping_tries 10 -validation_ratio 0.1 -validation_crops 80 -th cleanup_model.lua -model models/anime_style_art_rgb/scale2.0x_model.t7 -oformat ascii +th train.lua -method noise -noise_level 1 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4 +th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise1_model.t7 -oformat ascii +th train.lua -method noise -noise_level 2 -style art -model_dir models/anime_style_art_rgb -test images/miku_noisy.png -thread 4 +th tools/cleanup_model.lua -model models/anime_style_art_rgb/noise2_model.t7 -oformat ascii diff --git a/train_ukbench.sh b/train_ukbench.sh index eda0ec1..e72c526 100755 --- a/train_ukbench.sh +++ b/train_ukbench.sh @@ -1,11 +1,6 @@ #!/bin/sh -th train.lua -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 1 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg -th cleanup_model.lua -model models/ukbench2/noise1_model.t7 -oformat ascii - -th train.lua -core 1 -category photo -color rgb -color_noise 0 -overlay 0 -random_half 0 -epoch 300 -batch_size 1 -method noise -noise_level 2 -data_dir ukbench -model_dir models/ukbench2 -test photo2.jpg -th cleanup_model.lua -model models/ukbench2/noise2_model.t7 -oformat ascii - -th train.lua -category photo -color rgb -random_half 0 -epoch 400 -batch_size 1 -method scale -scale 2 -model_dir models/ukbench2 -data_dir ukbench -test photo2-noise.png -th cleanup_model.lua -model models/ukbench2/scale2.0x_model.t7 -oformat ascii +th convert_data.lua -data_dir ./data/ukbench +th train.lua -method scale -data_dir ./data/ukbench -model_dir models/ukbench -test images/lena.jpg -thread 4 +th tools/cleanup_model.lua -model models/ukbench/scale2.0x_model.t7 -oformat ascii diff --git a/waifu2x.lua b/waifu2x.lua index 4a29dae..72814ac 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -128,11 +128,11 @@ local function waifu2x() cmd:text() cmd:text("waifu2x") 
cmd:text("Options:") - cmd:option("-i", "images/miku_small.png", 'path of the input image') - cmd:option("-l", "", 'path of the image-list') + cmd:option("-i", "images/miku_small.png", 'path to input image') + cmd:option("-l", "", 'path to image-list.txt') cmd:option("-scale", 2, 'scale factor') - cmd:option("-o", "(auto)", 'path of the output file') - cmd:option("-model_dir", "./models/anime_style_art_rgb", 'model directory') + cmd:option("-o", "(auto)", 'path to output file') + cmd:option("-model_dir", "./models/anime_style_art_rgb", 'path to model directory') cmd:option("-m", "noise_scale", 'method (noise|scale|noise_scale)') cmd:option("-noise_level", 1, '(1|2)') cmd:option("-crop_size", 128, 'patch size per process')