2015-05-16 17:48:05 +12:00
|
|
|
require 'xlua'
|
|
|
|
require 'pl'
|
2015-10-26 13:23:52 +13:00
|
|
|
require 'trepl'
|
2015-05-16 17:48:05 +12:00
|
|
|
|
|
|
|
-- global settings
--
-- If a settings table has already been placed in package.preload.settings,
-- hand that back instead of parsing the command line again.
-- NOTE(review): package.preload normally holds loader *functions* (require
-- calls them); nothing in this file stores a table there. Confirm where this
-- is set, since require() would try to call it.
local preloaded_settings = package.preload.settings
if preloaded_settings then
   return preloaded_settings
end
|
|
|
|
|
|
|
|
-- Make torch.Tensor() and friends allocate single-precision float tensors.
torch.setdefaulttensortype("torch.FloatTensor")
|
|
|
|
|
|
|
|
local settings = {}

-- Command-line interface for training. Each entry is {flag, default, help};
-- entries are registered with cmd:option in table order, so the generated
-- help text lists them in this order.
local cmd = torch.CmdLine()
local option_specs = {
   {"-gpu", -1, 'GPU Device ID'},
   {"-seed", 11, 'RNG seed'},
   {"-data_dir", "./data", 'path to data directory'},
   {"-backend", "cunn", '(cunn|cudnn)'},
   {"-test", "images/miku_small.png", 'path to test image'},
   {"-model_dir", "./models", 'model directory'},
   {"-method", "scale", 'method to training (noise|scale|noise_scale)'},
   {"-model", "vgg_7", 'model architecture (vgg_7|vgg_12|upconv_7|upconv_8_4x|dilated_7)'},
   {"-noise_level", 1, '(1|2|3)'},
   {"-style", "art", '(art|photo)'},
   {"-color", 'rgb', '(y|rgb)'},
   {"-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)'},
   {"-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)'},
   {"-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)'},
   {"-random_unsharp_mask_rate", 0.0, 'data augmentation using unsharp mask (0.0-1.0)'},
   {"-scale", 2.0, 'scale factor (2)'},
   {"-learning_rate", 0.00025, 'learning rate for adam'},
   {"-crop_size", 48, 'crop size'},
   {"-max_size", 256, 'if image is larger than N, image will be crop randomly'},
   {"-batch_size", 16, 'mini batch size'},
   {"-patches", 64, 'number of patch samples'},
   {"-inner_epoch", 1, 'number of inner epochs'},
   {"-epoch", 100, 'number of epochs to run'},
   {"-thread", -1, 'number of CPU threads'},
   {"-jpeg_chroma_subsampling_rate", 0.0, 'the rate of YUV 4:2:0/YUV 4:4:4 in denoising training (0.0-1.0)'},
   {"-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)'},
   {"-validation_crops", 160, 'number of cropping region per image in validation'},
   {"-active_cropping_rate", 0.5, 'active cropping rate'},
   {"-active_cropping_tries", 10, 'active cropping tries'},
   {"-nr_rate", 0.75, 'trade-off between reducing noise and erasing details (0.0-1.0)'},
   {"-save_history", 0, 'save all model (0|1)'},
   {"-plot", 0, 'plot loss chart(0|1)'},
   {"-downsampling_filters", "Box,Lanczos,Sinc", '(comma separated)downsampling filters for 2x scale training. (Point,Box,Triangle,Hermite,Hanning,Hamming,Blackman,Gaussian,Quadratic,Cubic,Catrom,Mitchell,Lanczos,Bessel,Sinc)'},
   {"-gamma_correction", 0, 'Resizing with colorspace correction(sRGB:gamma 2.2) in scale training (0|1)'},
   {"-upsampling_filter", "Box", 'upsampling filter for 2x scale training (dev)'},
   {"-max_training_image_size", -1, 'if training image is larger than N, image will be crop randomly when data converting'},
   {"-use_transparent_png", 0, 'use transparent png (0|1)'},
   {"-resize_blur_min", 0.95, 'min blur parameter for ResizeImage'},
   {"-resize_blur_max", 1.05, 'max blur parameter for ResizeImage'},
   {"-oracle_rate", 0.1, ''},
   {"-oracle_drop_rate", 0.5, ''},
   {"-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))'},
   {"-loss", "y", 'loss (rgb|y)'},
   {"-resume", "", 'resume model file'},
}

cmd:text()
cmd:text("waifu2x-training")
cmd:text("Options:")
for _, spec in ipairs(option_specs) do
   cmd:option(spec[1], spec[2], spec[3])
end
|
2016-03-28 23:07:09 +13:00
|
|
|
|
|
|
|
-- Rewrite the 0/1 command-line flag settings[name] in place as a Lua boolean.
-- Only a value equal to the number 1 becomes true; everything else (0, nil,
-- strings, true, other numbers) becomes false.
local function to_bool(settings, name)
   settings[name] = (settings[name] == 1)
end
|
2015-05-16 17:48:05 +12:00
|
|
|
|
|
|
|
-- Parse the process arguments against the declared options and copy every
-- resulting option value into the settings table.
local parsed_opts = cmd:parse(arg)
for option_name, option_value in pairs(parsed_opts) do
   settings[option_name] = option_value
end
|
2016-03-28 23:07:09 +13:00
|
|
|
-- These options are declared as 0/1 numbers; convert them to booleans so
-- later code can test them directly (e.g. `if settings.plot then`).
for _, flag_name in ipairs({"plot", "save_history", "gamma_correction", "use_transparent_png"}) do
   to_bool(settings, flag_name)
end
|
2016-03-28 23:07:09 +13:00
|
|
|
|
|
|
|
-- gnuplot is only loaded when loss plotting was requested with -plot 1.
if settings.plot then require("gnuplot") end
|
|
|
|
-- Derive model file paths from -method, -model_dir, -noise_level and -scale.
-- Each method produces a path stem such as "<dir>/scale2.0x_model".
--   * with -save_history: model_file = stem .. ".%d-%d.t7" (two literal %d
--     placeholders, presumably filled in by the training code -- confirm)
--     and model_file_best = stem .. ".t7"
--   * otherwise: model_file = stem .. ".t7" (model_file_best is not set)
-- An unknown -method raises an error.
do
   local dir = settings.model_dir
   local stem
   if settings.method == "noise" then
      stem = string.format("%s/noise%d_model", dir, settings.noise_level)
   elseif settings.method == "scale" then
      stem = string.format("%s/scale%.1fx_model", dir, settings.scale)
   elseif settings.method == "noise_scale" then
      stem = string.format("%s/noise%d_scale%.1fx_model",
                           dir, settings.noise_level, settings.scale)
   else
      error("unknown method: " .. settings.method)
   end

   if settings.save_history then
      settings.model_file = stem .. ".%d-%d.t7"
      settings.model_file_best = stem .. ".t7"
   else
      settings.model_file = stem .. ".t7"
   end
end
|
2015-06-23 05:27:28 +12:00
|
|
|
-- Only "rgb" and "y" are accepted for -color.
do
   local supported_colors = { rgb = true, y = true }
   if not supported_colors[settings.color] then
      error("color must be y or rgb")
   end
end
|
2015-05-16 17:48:05 +12:00
|
|
|
-- -scale must be a positive, even integer (2, 4, ...).
-- The integer/mod-2 test alone also accepts 0 and negative even values
-- (0 % 2 == 0 and -2 % 2 == 0 in Lua), so positivity is checked first.
-- `not (x > 0)` also rejects NaN.
if not (settings.scale > 0) then
   error("scale must be positive")
end
if not (settings.scale == math.floor(settings.scale) and settings.scale % 2 == 0) then
   error("scale must be mod-2")
end
|
2015-11-06 14:08:54 +13:00
|
|
|
-- Only "art" and "photo" are accepted for -style.
do
   local supported_styles = { art = true, photo = true }
   if not supported_styles[settings.style] then
      error(string.format("unknown style: %s", settings.style))
   end
end
|
2015-10-28 19:30:47 +13:00
|
|
|
-- Apply -thread to torch only when it is positive; the default (-1) and
-- other non-positive values skip the torch.setnumthreads call entirely.
do
   local requested_threads = settings.thread
   if requested_threads > 0 then
      torch.setnumthreads(tonumber(requested_threads))
   end
end
|
2016-03-17 21:58:37 +13:00
|
|
|
-- Turn the comma separated -downsampling_filters option into a list of
-- filter names. Each entry is trimmed of surrounding whitespace and empty
-- entries are dropped, so "Box, Lanczos" and "Box,,Sinc" yield clean names
-- instead of " Lanczos" or "". If the option is empty (or names no filter
-- at all), the built-in fallback list is used.
-- NOTE(review): the fallback ends in "Catrom" while the CLI default ends in
-- "Sinc"; confirm this difference is intended.
do
   local spec = settings.downsampling_filters
   local filters = {}
   if type(spec) == "string" then
      for item in spec:gmatch("[^,]+") do
         local name = item:match("^%s*(.-)%s*$")
         if name ~= "" then
            filters[#filters + 1] = name
         end
      end
   end
   if #filters == 0 then
      filters = {"Box", "Lanczos", "Catrom"}
   end
   settings.downsampling_filters = filters
end
|
2015-05-16 17:48:05 +12:00
|
|
|
|
|
|
|
-- File paths derived from -data_dir.
local data_dir = settings.data_dir
settings.images = ("%s/images.t7"):format(data_dir)
settings.image_list = ("%s/image_list.txt"):format(data_dir)

return settings
|