diff --git a/lib/data_augmentation.lua b/lib/data_augmentation.lua index ab6aa0e..d31227a 100644 --- a/lib/data_augmentation.lua +++ b/lib/data_augmentation.lua @@ -11,25 +11,44 @@ local function pcacov(x) local ce, cv = torch.symeig(c, 'V') return ce, cv end -function data_augmentation.color_noise(src, factor) +function data_augmentation.color_noise(src, p, factor) factor = factor or 0.1 - local src, conversion = iproc.byte2float(src) - local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous() - local ce, cv = pcacov(src_t) - local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor) - - pca_space = torch.mm(src_t, cv):t():contiguous() - for i = 1, 3 do - pca_space[i]:mul(color_scale[i]) - end - local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src) - dest[torch.lt(dest, 0.0)] = 0.0 - dest[torch.gt(dest, 1.0)] = 1.0 + if torch.uniform() < p then + local src, conversion = iproc.byte2float(src) + local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous() + local ce, cv = pcacov(src_t) + local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor) + + pca_space = torch.mm(src_t, cv):t():contiguous() + for i = 1, 3 do + pca_space[i]:mul(color_scale[i]) + end + local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src) + dest[torch.lt(dest, 0.0)] = 0.0 + dest[torch.gt(dest, 1.0)] = 1.0 - if conversion then - dest = iproc.float2byte(dest) + if conversion then + dest = iproc.float2byte(dest) + end + return dest + else + return src + end +end +function data_augmentation.overlay(src, p) + if torch.uniform() < p then + local r = torch.uniform() + local src, conversion = iproc.byte2float(src) + src = src:contiguous() + local flip = data_augmentation.flip(src) + flip:mul(r):add(src * (1.0 - r)) + if conversion then + flip = iproc.float2byte(flip) + end + return flip + else + return src end - return dest end function data_augmentation.shift_1px(src) -- reducing the even/odd issue in nearest neighbor scaler. @@ -76,20 +95,4 @@ function data_augmentation.flip(src) end return dest end -function data_augmentation.overlay(src, p) - p = p or 0.25 - if torch.uniform() < p then - local r = torch.uniform(0.2, 0.8) - local src, conversion = iproc.byte2float(src) - src = src:contiguous() - local flip = data_augmentation.flip(src) - flip:mul(r):add(src * (1.0 - r)) - if conversion then - flip = iproc.float2byte(flip) - end - return flip - else - return src - end -end return data_augmentation diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua index abf649f..b7509d3 100644 --- a/lib/pairwise_transform.lua +++ b/lib/pairwise_transform.lua @@ -6,9 +6,8 @@ local data_augmentation = require 'data_augmentation' local pairwise_transform = {} local function random_half(src, p) - p = p or 0.25 - local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)] - if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then + if torch.uniform() < p then + local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)] return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter) else return src @@ -34,17 +33,11 @@ local function crop_if_large(src, max_size) end local function preprocess(src, crop_size, options) local dest = src - if options.random_half then - dest = random_half(dest) - end + dest = random_half(dest, options.random_half_rate) dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size)) dest = data_augmentation.flip(dest) - if options.color_noise then - dest = data_augmentation.color_noise(dest) - end - if options.overlay then - dest = data_augmentation.overlay(dest) - end + dest = data_augmentation.color_noise(dest, options.random_color_noise_rate) + dest = data_augmentation.overlay(dest, options.random_overlay_rate) dest = data_augmentation.shift_1px(dest) return dest diff --git a/lib/settings.lua b/lib/settings.lua index 222f7d0..f976ba6 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -26,19 +26,19 @@ cmd:option("-method", "scale", 'method to training (noise|scale)') cmd:option("-noise_level", 1, '(1|2)') cmd:option("-style", "art", '(art|photo)') cmd:option("-color", 'rgb', '(y|rgb)') -cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)') -cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)') +cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)') +cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)') +cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)') cmd:option("-scale", 2.0, 'scale factor (2)') cmd:option("-learning_rate", 0.00025, 'learning rate for adam') -cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)') cmd:option("-crop_size", 46, 'crop size') cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly') cmd:option("-batch_size", 8, 'mini batch size') cmd:option("-epoch", 200, 'number of total epochs to run') cmd:option("-thread", -1, 'number of CPU threads') cmd:option("-jpeg_sampling_factors", 444, '(444|420)') -cmd:option("-validation_rate", 0.05, 'validation-set rate of data') -cmd:option("-validation_crops", 80, 'number of region per image in validation') +cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)') +cmd:option("-validation_crops", 80, 'number of cropping region per image in validation') cmd:option("-active_cropping_rate", 0.5, 'active cropping rate') cmd:option("-active_cropping_tries", 10, 'active cropping tries') cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)') @@ -69,21 +69,6 @@ if not (settings.style == "art" or settings.style == "photo") then error(string.format("unknown style: %s", settings.style)) end -if settings.random_half == 1 then - settings.random_half = true -else - settings.random_half = false -end -if settings.color_noise == 1 then - settings.color_noise = true -else - settings.color_noise = false -end -if settings.overlay == 1 then - settings.overlay = true -else - settings.overlay = false -end if settings.thread > 0 then torch.setnumthreads(tonumber(settings.thread)) diff --git a/train.lua b/train.lua index 1fd4325..7139807 100644 --- a/train.lua +++ b/train.lua @@ -85,20 +85,20 @@ local function transformer(x, is_validation, n, offset) x = compression.decompress(x) n = n or settings.batch_size; if is_validation == nil then is_validation = false end - local color_noise = nil - local overlay = nil + local random_color_noise_rate = nil + local random_overlay_rate = nil local active_cropping_rate = nil local active_cropping_tries = nil if is_validation then active_cropping_rate = 0 active_cropping_tries = 0 - color_noise = false - overlay = false + random_color_noise_rate = 0.0 + random_overlay_rate = 0.0 else active_cropping_rate = settings.active_cropping_rate active_cropping_tries = settings.active_cropping_tries - color_noise = settings.color_noise - overlay = settings.overlay + random_color_noise_rate = settings.random_color_noise_rate + random_overlay_rate = settings.random_overlay_rate end if settings.method == "scale" then @@ -106,13 +106,14 @@ local function transformer(x, is_validation, n, offset) settings.scale, settings.crop_size, offset, n, - { color_noise = color_noise, - overlay = overlay, - random_half = settings.random_half, - max_size = settings.max_size, - active_cropping_rate = active_cropping_rate, - active_cropping_tries = active_cropping_tries, - rgb = (settings.color == "rgb") + { + random_half_rate = settings.random_half_rate, + random_color_noise_rate = random_color_noise_rate, + random_overlay_rate = random_overlay_rate, + max_size = settings.max_size, + active_cropping_rate = active_cropping_rate, + active_cropping_tries = active_cropping_tries, + rgb = (settings.color == "rgb") }) elseif settings.method == "noise" then return pairwise_transform.jpeg(x, @@ -120,15 +121,16 @@ local function transformer(x, is_validation, n, offset) settings.noise_level, settings.crop_size, offset, n, - { color_noise = color_noise, - overlay = overlay, - random_half = settings.random_half, - max_size = settings.max_size, - jpeg_sampling_factors = settings.jpeg_sampling_factors, - active_cropping_rate = active_cropping_rate, - active_cropping_tries = active_cropping_tries, - nr_rate = settings.nr_rate, - rgb = (settings.color == "rgb") + { + random_half_rate = settings.random_half_rate, + random_color_noise_rate = random_color_noise_rate, + random_overlay_rate = random_overlay_rate, + max_size = settings.max_size, + jpeg_sampling_factors = settings.jpeg_sampling_factors, + active_cropping_rate = active_cropping_rate, + active_cropping_tries = active_cropping_tries, + nr_rate = settings.nr_rate, + rgb = (settings.color == "rgb") }) end end