tunable parameters
This commit is contained in:
parent
cf862782a5
commit
3ea16b3b86
4 changed files with 69 additions and 86 deletions
|
@ -11,8 +11,9 @@ local function pcacov(x)
|
|||
local ce, cv = torch.symeig(c, 'V')
|
||||
return ce, cv
|
||||
end
|
||||
function data_augmentation.color_noise(src, factor)
|
||||
function data_augmentation.color_noise(src, p, factor)
|
||||
factor = factor or 0.1
|
||||
if torch.uniform() < p then
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
local src_t = src:reshape(src:size(1), src:nElement() / src:size(1)):t():contiguous()
|
||||
local ce, cv = pcacov(src_t)
|
||||
|
@ -30,6 +31,24 @@ function data_augmentation.color_noise(src, factor)
|
|||
dest = iproc.float2byte(dest)
|
||||
end
|
||||
return dest
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
function data_augmentation.overlay(src, p)
|
||||
if torch.uniform() < p then
|
||||
local r = torch.uniform()
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
src = src:contiguous()
|
||||
local flip = data_augmentation.flip(src)
|
||||
flip:mul(r):add(src * (1.0 - r))
|
||||
if conversion then
|
||||
flip = iproc.float2byte(flip)
|
||||
end
|
||||
return flip
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
function data_augmentation.shift_1px(src)
|
||||
-- reducing the even/odd issue in nearest neighbor scaler.
|
||||
|
@ -76,20 +95,4 @@ function data_augmentation.flip(src)
|
|||
end
|
||||
return dest
|
||||
end
|
||||
function data_augmentation.overlay(src, p)
|
||||
p = p or 0.25
|
||||
if torch.uniform() < p then
|
||||
local r = torch.uniform(0.2, 0.8)
|
||||
local src, conversion = iproc.byte2float(src)
|
||||
src = src:contiguous()
|
||||
local flip = data_augmentation.flip(src)
|
||||
flip:mul(r):add(src * (1.0 - r))
|
||||
if conversion then
|
||||
flip = iproc.float2byte(flip)
|
||||
end
|
||||
return flip
|
||||
else
|
||||
return src
|
||||
end
|
||||
end
|
||||
return data_augmentation
|
||||
|
|
|
@ -6,9 +6,8 @@ local data_augmentation = require 'data_augmentation'
|
|||
local pairwise_transform = {}
|
||||
|
||||
local function random_half(src, p)
|
||||
p = p or 0.25
|
||||
if torch.uniform() < p then
|
||||
local filter = ({"Box","Box","Blackman","SincFast","Jinc"})[torch.random(1, 5)]
|
||||
if p < torch.uniform() and (src:size(2) > 768 and src:size(3) > 1024) then
|
||||
return iproc.scale(src, src:size(3) * 0.5, src:size(2) * 0.5, filter)
|
||||
else
|
||||
return src
|
||||
|
@ -34,17 +33,11 @@ local function crop_if_large(src, max_size)
|
|||
end
|
||||
local function preprocess(src, crop_size, options)
|
||||
local dest = src
|
||||
if options.random_half then
|
||||
dest = random_half(dest)
|
||||
end
|
||||
dest = random_half(dest, options.random_half_rate)
|
||||
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
if options.color_noise then
|
||||
dest = data_augmentation.color_noise(dest)
|
||||
end
|
||||
if options.overlay then
|
||||
dest = data_augmentation.overlay(dest)
|
||||
end
|
||||
dest = data_augmentation.color_noise(dest, options.random_color_noise_rate)
|
||||
dest = data_augmentation.overlay(dest, options.random_overlay_rate)
|
||||
dest = data_augmentation.shift_1px(dest)
|
||||
|
||||
return dest
|
||||
|
|
|
@ -26,19 +26,19 @@ cmd:option("-method", "scale", 'method to training (noise|scale)')
|
|||
cmd:option("-noise_level", 1, '(1|2)')
|
||||
cmd:option("-style", "art", '(art|photo)')
|
||||
cmd:option("-color", 'rgb', '(y|rgb)')
|
||||
cmd:option("-color_noise", 0, 'data augmentation using color noise (1|0)')
|
||||
cmd:option("-overlay", 0, 'data augmentation using overlay (1|0)')
|
||||
cmd:option("-random_color_noise_rate", 0.0, 'data augmentation using color noise (0.0-1.0)')
|
||||
cmd:option("-random_overlay_rate", 0.0, 'data augmentation using flipped image overlay (0.0-1.0)')
|
||||
cmd:option("-random_half_rate", 0.0, 'data augmentation using half resolution image (0.0-1.0)')
|
||||
cmd:option("-scale", 2.0, 'scale factor (2)')
|
||||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||
cmd:option("-random_half", 0, 'data augmentation using half resolution image (0|1)')
|
||||
cmd:option("-crop_size", 46, 'crop size')
|
||||
cmd:option("-max_size", 256, 'if image is larger than max_size, image will be crop to max_size randomly')
|
||||
cmd:option("-batch_size", 8, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'number of total epochs to run')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
cmd:option("-jpeg_sampling_factors", 444, '(444|420)')
|
||||
cmd:option("-validation_rate", 0.05, 'validation-set rate of data')
|
||||
cmd:option("-validation_crops", 80, 'number of region per image in validation')
|
||||
cmd:option("-validation_rate", 0.05, 'validation-set rate (number_of_training_images * validation_rate > 1)')
|
||||
cmd:option("-validation_crops", 80, 'number of cropping region per image in validation')
|
||||
cmd:option("-active_cropping_rate", 0.5, 'active cropping rate')
|
||||
cmd:option("-active_cropping_tries", 10, 'active cropping tries')
|
||||
cmd:option("-nr_rate", 0.7, 'trade-off between reducing noise and erasing details (0.0-1.0)')
|
||||
|
@ -69,21 +69,6 @@ if not (settings.style == "art" or
|
|||
settings.style == "photo") then
|
||||
error(string.format("unknown style: %s", settings.style))
|
||||
end
|
||||
if settings.random_half == 1 then
|
||||
settings.random_half = true
|
||||
else
|
||||
settings.random_half = false
|
||||
end
|
||||
if settings.color_noise == 1 then
|
||||
settings.color_noise = true
|
||||
else
|
||||
settings.color_noise = false
|
||||
end
|
||||
if settings.overlay == 1 then
|
||||
settings.overlay = true
|
||||
else
|
||||
settings.overlay = false
|
||||
end
|
||||
|
||||
if settings.thread > 0 then
|
||||
torch.setnumthreads(tonumber(settings.thread))
|
||||
|
|
26
train.lua
26
train.lua
|
@ -85,20 +85,20 @@ local function transformer(x, is_validation, n, offset)
|
|||
x = compression.decompress(x)
|
||||
n = n or settings.batch_size;
|
||||
if is_validation == nil then is_validation = false end
|
||||
local color_noise = nil
|
||||
local overlay = nil
|
||||
local random_color_noise_rate = nil
|
||||
local random_overlay_rate = nil
|
||||
local active_cropping_rate = nil
|
||||
local active_cropping_tries = nil
|
||||
if is_validation then
|
||||
active_cropping_rate = 0
|
||||
active_cropping_tries = 0
|
||||
color_noise = false
|
||||
overlay = false
|
||||
random_color_noise_rate = 0.0
|
||||
random_overlay_rate = 0.0
|
||||
else
|
||||
active_cropping_rate = settings.active_cropping_rate
|
||||
active_cropping_tries = settings.active_cropping_tries
|
||||
color_noise = settings.color_noise
|
||||
overlay = settings.overlay
|
||||
random_color_noise_rate = settings.random_color_noise_rate
|
||||
random_overlay_rate = settings.random_overlay_rate
|
||||
end
|
||||
|
||||
if settings.method == "scale" then
|
||||
|
@ -106,9 +106,10 @@ local function transformer(x, is_validation, n, offset)
|
|||
settings.scale,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
random_half = settings.random_half,
|
||||
{
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
|
@ -120,9 +121,10 @@ local function transformer(x, is_validation, n, offset)
|
|||
settings.noise_level,
|
||||
settings.crop_size, offset,
|
||||
n,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
random_half = settings.random_half,
|
||||
{
|
||||
random_half_rate = settings.random_half_rate,
|
||||
random_color_noise_rate = random_color_noise_rate,
|
||||
random_overlay_rate = random_overlay_rate,
|
||||
max_size = settings.max_size,
|
||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
|
|
Loading…
Reference in a new issue