1
0
Fork 0
mirror of synced 2024-06-28 03:00:54 +12:00

tunable parameters

This commit is contained in:
nagadomi 2015-11-07 07:18:22 +09:00
parent cf862782a5
commit 3ea16b3b86
4 changed files with 69 additions and 86 deletions

View file

@ -11,25 +11,44 @@ local function pcacov(x)
local ce, cv = torch.symeig(c, 'V') local ce, cv = torch.symeig(c, 'V')
return ce, cv return ce, cv
end end
function data_augmentation.color_noise(src, factor) function data_augmentation.color_noise(src, p, factor)
factor = factor or 0.1 factor = factor or 0.1
local src, conversion = iproc.byte2float(src) if torch.uniform() < p then
local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous() local src, conversion = iproc.byte2float(src)
local ce, cv = pcacov(src_t) local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor) local ce, cv = pcacov(src_t)
local color_scale = torch.Tensor(3):uniform(1 / (1 + factor), 1 + factor)
pca_space = torch.mm(src_t, cv):t():contiguous()
for i = 1, 3 do pca_space = torch.mm(src_t, cv):t():contiguous()
pca_space[i]:mul(color_scale[i]) for i = 1, 3 do
end pca_space[i]:mul(color_scale[i])
local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src) end
dest[torch.lt(dest, 0.0)] = 0.0 local dest = torch.mm(pca_space:t(), cv:t()):t():contiguous():resizeAs(src)
dest[torch.gt(dest, 1.0)] = 1.0 dest[torch.lt(dest, 0.0)] = 0.0
dest[torch.gt(dest, 1.0)] = 1.0
if conversion then if conversion then
dest = iproc.float2byte(dest) dest = iproc.float2byte(dest)
end
return dest
else
return src
end
end
function data_augmentation.overlay(src, p)
if torch.uniform() < p then
local r = torch.uniform()
local src, conversion = iproc.byte2float(src)
src = src:contiguous()
local flip = data_augmentation.flip(src)
flip:mul(r):add(src * (1.0 - r))
if conversion then
flip = iproc.float2byte(flip)
end
return flip
else
return src
end end
return dest
end end
function data_augmentation.shift_1px(src) function data_augmentation.shift_1px(src)
-- reducing the even/odd issue in nearest neighbor scaler. -- reducing the even/odd issue in nearest neighbor scaler.
@ -76,20 +95,4 @@ function data_augmentation.flip(src)
end end
return dest return dest
end end
--- Randomly mix an image with the result of data_augmentation.flip on it.
-- With probability `p` the result is `flip(src) * r + src * (1 - r)`, where
-- `r` is drawn uniformly from [0.2, 0.8]; otherwise `src` is returned as-is.
-- @param src input image tensor (passed through iproc.byte2float first;
--            presumably a byte or float image — confirm against callers)
-- @param p   probability of applying the overlay; defaults to 0.25
-- @return the mixed image, converted back with iproc.float2byte when
--         byte2float reported a conversion, or the untouched `src`
function data_augmentation.overlay(src, p)
   p = p or 0.25
   -- Guard clause: skip the augmentation unless the random draw falls below p.
   if not (torch.uniform() < p) then
      return src
   end

   local weight = torch.uniform(0.2, 0.8)
   local image, was_converted = iproc.byte2float(src)
   image = image:contiguous()

   -- `mixed` is scaled in place and then receives the weighted original.
   local mixed = data_augmentation.flip(image)
   mixed:mul(weight):add(image * (1.0 - weight))

   -- Restore the original representation if byte2float flagged a conversion.
   if was_converted then
      mixed = iproc.float2byte(mixed)
   end
   return mixed
end
return data_augmentation return data_augmentation

View file

@ -6,9 +6,8 @@ local data_augmentation = require 'data_augmentation'
local pairwise_transform = {} local pairwise_transform = {}
local function random_half(src, p) local function random_half(src, p)
p = p or 0.25 if torch.uniform() < p then
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)] local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter) return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
else else
return src return src
@ -34,17 +33,11 @@ local function crop_if_large(src, max_size)
end end
local function preprocess(src, crop_size, options) local function preprocess(src, crop_size, options)
local dest = src local dest = src
if options.random_half then dest = random_half(dest, options.random_half_rate)
dest = random_half(dest)
end
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size)) dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
dest = data_augmentation.flip(dest) dest = data_augmentation.flip(dest)
if options.color_noise then dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
dest = data_augmentation.color_noise(dest) dest = data_augmentation.overlay(dest, options.random_overlay_rate)
end
if options.overlay then
dest = data_augmentation.overlay(dest)
end
dest = data_augmentation.shift_1px(dest) dest = data_augmentation.shift_1px(dest)
return dest return dest

View file

@ -26,19 +26,19 @@ cmd:option("-method", "scale", 'method to training (noise|scale)')
cmd:option("-noise_level", 1, '(1|2)') cmd:option("-noise_level", 1, '(1|2)')
cmd:option("-style", "art", '(art|photo)') cmd:option("-style", "art", '(art|photo)')
cmd:option("-color", 'rgb', '(y|rgb)') cmd:option("-color", 'rgb', '(y|rgb)')
cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)') cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)') cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
cmd:option("-scale", 2.0, 'scale factor (2)') cmd:option("-scale", 2.0, 'scale factor (2)')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam') cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
cmd:option("-crop_size", 46, 'crop size') cmd:option("-crop_size", 46, 'crop size')
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly') cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
cmd:option("-batch_size", 8, 'mini batch size') cmd:option("-batch_size", 8, 'mini batch size')
cmd:option("-epoch", 200, 'number of total epochs to run') cmd:option("-epoch", 200, 'number of total epochs to run')
cmd:option("-thread", -1, 'number of CPU threads') cmd:option("-thread", -1, 'number of CPU threads')
cmd:option("-jpeg_sampling_factors", 444, '(444|420)') cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
cmd:option("-validation_rate", 0.05, 'validation-set rate of data') cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
cmd:option("-validation_crops", 80, 'number of region per image in validation') cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate') cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
cmd:option("-active_cropping_tries", 10, 'active cropping tries') cmd:option("-active_cropping_tries", 10, 'active cropping tries')
cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)') cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)')
@ -69,21 +69,6 @@ if not (settings.style == "art" or
settings.style == "photo") then settings.style == "photo") then
error(string.format("unknown style: %s", settings.style)) error(string.format("unknown style: %s", settings.style))
end end
-- Turn the numeric on/off options into real booleans: a flag becomes true
-- only when its value is exactly 1, and false for anything else.
for _, flag_name in ipairs({"random_half", "color_noise", "overlay"}) do
   settings[flag_name] = (settings[flag_name] == 1)
end
if settings.thread > 0 then if settings.thread > 0 then
torch.setnumthreads(tonumber(settings.thread)) torch.setnumthreads(tonumber(settings.thread))

View file

@ -85,20 +85,20 @@ local function transformer(x, is_validation, n, offset)
x = compression.decompress(x) x = compression.decompress(x)
n = n or settings.batch_size; n = n or settings.batch_size;
if is_validation == nil then is_validation = false end if is_validation == nil then is_validation = false end
local color_noise = nil local random_color_noise_rate = nil
local overlay = nil local random_overlay_rate = nil
local active_cropping_rate = nil local active_cropping_rate = nil
local active_cropping_tries = nil local active_cropping_tries = nil
if is_validation then if is_validation then
active_cropping_rate = 0 active_cropping_rate = 0
active_cropping_tries = 0 active_cropping_tries = 0
color_noise = false random_color_noise_rate = 0.0
overlay = false random_overlay_rate = 0.0
else else
active_cropping_rate = settings.active_cropping_rate active_cropping_rate = settings.active_cropping_rate
active_cropping_tries = settings.active_cropping_tries active_cropping_tries = settings.active_cropping_tries
color_noise = settings.color_noise random_color_noise_rate = settings.random_color_noise_rate
overlay = settings.overlay random_overlay_rate = settings.random_overlay_rate
end end
if settings.method == "scale" then if settings.method == "scale" then
@ -106,13 +106,14 @@ local function transformer(x, is_validation, n, offset)
settings.scale, settings.scale,
settings.crop_size, offset, settings.crop_size, offset,
n, n,
{ color_noise = color_noise, {
overlay = overlay, random_half_rate = settings.random_half_rate,
random_half = settings.random_half, random_color_noise_rate = random_color_noise_rate,
max_size = settings.max_size, random_overlay_rate = random_overlay_rate,
active_cropping_rate = active_cropping_rate, max_size = settings.max_size,
active_cropping_tries = active_cropping_tries, active_cropping_rate = active_cropping_rate,
rgb = (settings.color == "rgb") active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb")
}) })
elseif settings.method == "noise" then elseif settings.method == "noise" then
return pairwise_transform.jpeg(x, return pairwise_transform.jpeg(x,
@ -120,15 +121,16 @@ local function transformer(x, is_validation, n, offset)
settings.noise_level, settings.noise_level,
settings.crop_size, offset, settings.crop_size, offset,
n, n,
{ color_noise = color_noise, {
overlay = overlay, random_half_rate = settings.random_half_rate,
random_half = settings.random_half, random_color_noise_rate = random_color_noise_rate,
max_size = settings.max_size, random_overlay_rate = random_overlay_rate,
jpeg_sampling_factors = settings.jpeg_sampling_factors, max_size = settings.max_size,
active_cropping_rate = active_cropping_rate, jpeg_sampling_factors = settings.jpeg_sampling_factors,
active_cropping_tries = active_cropping_tries, active_cropping_rate = active_cropping_rate,
nr_rate = settings.nr_rate, active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb") nr_rate = settings.nr_rate,
rgb = (settings.color == "rgb")
}) })
end end
end end