tunable parameters
This commit is contained in:
parent
cf862782a5
commit
3ea16b3b86
4 changed files with 69 additions and 86 deletions
|
@ -11,8 +11,9 @@ local function pcacov(x)
|
||||||
local ce, cv = torch.symeig(c, 'V')
|
local ce, cv = torch.symeig(c, 'V')
|
||||||
return ce, cv
|
return ce, cv
|
||||||
end
|
end
|
||||||
function data_augmentation.color_noise(src, factor)
|
function data_augmentation.color_noise(src, p, factor)
|
||||||
factor = factor or 0.1
|
factor = factor or 0.1
|
||||||
|
if torch.uniform() < p then
|
||||||
local src, conversion = iproc.byte2float(src)
|
local src, conversion = iproc.byte2float(src)
|
||||||
local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
|
local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
|
||||||
local ce, cv = pcacov(src_t)
|
local ce, cv = pcacov(src_t)
|
||||||
|
@ -30,6 +31,24 @@ function data_augmentation.color_noise(src, factor)
|
||||||
dest = iproc.float2byte(dest)
|
dest = iproc.float2byte(dest)
|
||||||
end
|
end
|
||||||
return dest
|
return dest
|
||||||
|
else
|
||||||
|
return src
|
||||||
|
end
|
||||||
|
end
|
||||||
|
function data_augmentation.overlay(src, p)
|
||||||
|
if torch.uniform() < p then
|
||||||
|
local r = torch.uniform()
|
||||||
|
local src, conversion = iproc.byte2float(src)
|
||||||
|
src = src:contiguous()
|
||||||
|
local flip = data_augmentation.flip(src)
|
||||||
|
flip:mul(r):add(src * (1.0 - r))
|
||||||
|
if conversion then
|
||||||
|
flip = iproc.float2byte(flip)
|
||||||
|
end
|
||||||
|
return flip
|
||||||
|
else
|
||||||
|
return src
|
||||||
|
end
|
||||||
end
|
end
|
||||||
function data_augmentation.shift_1px(src)
|
function data_augmentation.shift_1px(src)
|
||||||
-- reducing the even/odd issue in nearest neighbor scaler.
|
-- reducing the even/odd issue in nearest neighbor scaler.
|
||||||
|
@ -76,20 +95,4 @@ function data_augmentation.flip(src)
|
||||||
end
|
end
|
||||||
return dest
|
return dest
|
||||||
end
|
end
|
||||||
function data_augmentation.overlay(src, p)
|
|
||||||
p = p or 0.25
|
|
||||||
if torch.uniform() < p then
|
|
||||||
local r = torch.uniform(0.2, 0.8)
|
|
||||||
local src, conversion = iproc.byte2float(src)
|
|
||||||
src = src:contiguous()
|
|
||||||
local flip = data_augmentation.flip(src)
|
|
||||||
flip:mul(r):add(src * (1.0 - r))
|
|
||||||
if conversion then
|
|
||||||
flip = iproc.float2byte(flip)
|
|
||||||
end
|
|
||||||
return flip
|
|
||||||
else
|
|
||||||
return src
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return data_augmentation
|
return data_augmentation
|
||||||
|
|
|
@ -6,9 +6,8 @@ local data_augmentation = require 'data_augmentation'
|
||||||
local pairwise_transform = {}
|
local pairwise_transform = {}
|
||||||
|
|
||||||
local function random_half(src, p)
|
local function random_half(src, p)
|
||||||
p = p or 0.25
|
if torch.uniform() < p then
|
||||||
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
|
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
|
||||||
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
|
|
||||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||||
else
|
else
|
||||||
return src
|
return src
|
||||||
|
@ -34,17 +33,11 @@ local function crop_if_large(src, max_size)
|
||||||
end
|
end
|
||||||
local function preprocess(src, crop_size, options)
|
local function preprocess(src, crop_size, options)
|
||||||
local dest = src
|
local dest = src
|
||||||
if options.random_half then
|
dest = random_half(dest, options.random_half_rate)
|
||||||
dest = random_half(dest)
|
|
||||||
end
|
|
||||||
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||||
dest = data_augmentation.flip(dest)
|
dest = data_augmentation.flip(dest)
|
||||||
if options.color_noise then
|
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||||
dest = data_augmentation.color_noise(dest)
|
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||||
end
|
|
||||||
if options.overlay then
|
|
||||||
dest = data_augmentation.overlay(dest)
|
|
||||||
end
|
|
||||||
dest = data_augmentation.shift_1px(dest)
|
dest = data_augmentation.shift_1px(dest)
|
||||||
|
|
||||||
return dest
|
return dest
|
||||||
|
|
|
@ -26,19 +26,19 @@ cmd:option("-method", "scale", 'method to training (noise|scale)')
|
||||||
cmd:option("-noise_level", 1, '(1|2)')
|
cmd:option("-noise_level", 1, '(1|2)')
|
||||||
cmd:option("-style", "art", '(art|photo)')
|
cmd:option("-style", "art", '(art|photo)')
|
||||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||||
cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)')
|
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
|
||||||
cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)')
|
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
|
||||||
|
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
|
||||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||||
cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
|
|
||||||
cmd:option("-crop_size", 46, 'crop size')
|
cmd:option("-crop_size", 46, 'crop size')
|
||||||
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
|
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
|
||||||
cmd:option("-batch_size", 8, 'mini batch size')
|
cmd:option("-batch_size", 8, 'mini batch size')
|
||||||
cmd:option("-epoch", 200, 'number of total epochs to run')
|
cmd:option("-epoch", 200, 'number of total epochs to run')
|
||||||
cmd:option("-thread", -1, 'number of CPU threads')
|
cmd:option("-thread", -1, 'number of CPU threads')
|
||||||
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
|
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
|
||||||
cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
|
cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
|
||||||
cmd:option("-validation_crops", 80, 'number of region per image in validation')
|
cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
|
||||||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||||
cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
||||||
|
@ -69,21 +69,6 @@ if not (settings.style == "art" or
|
||||||
settings.style == "photo") then
|
settings.style == "photo") then
|
||||||
error(string.format("unknown style: %s", settings.style))
|
error(string.format("unknown style: %s", settings.style))
|
||||||
end
|
end
|
||||||
if settings.random_half == 1 then
|
|
||||||
settings.random_half = true
|
|
||||||
else
|
|
||||||
settings.random_half = false
|
|
||||||
end
|
|
||||||
if settings.color_noise == 1 then
|
|
||||||
settings.color_noise = true
|
|
||||||
else
|
|
||||||
settings.color_noise = false
|
|
||||||
end
|
|
||||||
if settings.overlay == 1 then
|
|
||||||
settings.overlay = true
|
|
||||||
else
|
|
||||||
settings.overlay = false
|
|
||||||
end
|
|
||||||
|
|
||||||
if settings.thread > 0 then
|
if settings.thread > 0 then
|
||||||
torch.setnumthreads(tonumber(settings.thread))
|
torch.setnumthreads(tonumber(settings.thread))
|
||||||
|
|
26
train.lua
26
train.lua
|
@ -85,20 +85,20 @@ local function transformer(x, is_validation, n, offset)
|
||||||
x = compression.decompress(x)
|
x = compression.decompress(x)
|
||||||
n = n or settings.batch_size;
|
n = n or settings.batch_size;
|
||||||
if is_validation == nil then is_validation = false end
|
if is_validation == nil then is_validation = false end
|
||||||
local color_noise = nil
|
local random_color_noise_rate = nil
|
||||||
local overlay = nil
|
local random_overlay_rate = nil
|
||||||
local active_cropping_rate = nil
|
local active_cropping_rate = nil
|
||||||
local active_cropping_tries = nil
|
local active_cropping_tries = nil
|
||||||
if is_validation then
|
if is_validation then
|
||||||
active_cropping_rate = 0
|
active_cropping_rate = 0
|
||||||
active_cropping_tries = 0
|
active_cropping_tries = 0
|
||||||
color_noise = false
|
random_color_noise_rate = 0.0
|
||||||
overlay = false
|
random_overlay_rate = 0.0
|
||||||
else
|
else
|
||||||
active_cropping_rate = settings.active_cropping_rate
|
active_cropping_rate = settings.active_cropping_rate
|
||||||
active_cropping_tries = settings.active_cropping_tries
|
active_cropping_tries = settings.active_cropping_tries
|
||||||
color_noise = settings.color_noise
|
random_color_noise_rate = settings.random_color_noise_rate
|
||||||
overlay = settings.overlay
|
random_overlay_rate = settings.random_overlay_rate
|
||||||
end
|
end
|
||||||
|
|
||||||
if settings.method == "scale" then
|
if settings.method == "scale" then
|
||||||
|
@ -106,9 +106,10 @@ local function transformer(x, is_validation, n, offset)
|
||||||
settings.scale,
|
settings.scale,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
n,
|
n,
|
||||||
{ color_noise = color_noise,
|
{
|
||||||
overlay = overlay,
|
random_half_rate = settings.random_half_rate,
|
||||||
random_half = settings.random_half,
|
random_color_noise_rate = random_color_noise_rate,
|
||||||
|
random_overlay_rate = random_overlay_rate,
|
||||||
max_size = settings.max_size,
|
max_size = settings.max_size,
|
||||||
active_cropping_rate = active_cropping_rate,
|
active_cropping_rate = active_cropping_rate,
|
||||||
active_cropping_tries = active_cropping_tries,
|
active_cropping_tries = active_cropping_tries,
|
||||||
|
@ -120,9 +121,10 @@ local function transformer(x, is_validation, n, offset)
|
||||||
settings.noise_level,
|
settings.noise_level,
|
||||||
settings.crop_size, offset,
|
settings.crop_size, offset,
|
||||||
n,
|
n,
|
||||||
{ color_noise = color_noise,
|
{
|
||||||
overlay = overlay,
|
random_half_rate = settings.random_half_rate,
|
||||||
random_half = settings.random_half,
|
random_color_noise_rate = random_color_noise_rate,
|
||||||
|
random_overlay_rate = random_overlay_rate,
|
||||||
max_size = settings.max_size,
|
max_size = settings.max_size,
|
||||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||||
active_cropping_rate = active_cropping_rate,
|
active_cropping_rate = active_cropping_rate,
|
||||||
|
|
Loading…
Reference in a new issue