tuning
This commit is contained in:
parent
425898a3aa
commit
b35a9ae7d7
5 changed files with 58 additions and 22 deletions
|
@ -1,5 +1,6 @@
|
|||
local gm = require 'graphicsmagick'
|
||||
local image = require 'image'
|
||||
|
||||
local iproc = {}
|
||||
|
||||
function iproc.crop_mod4(src)
|
||||
|
@ -16,6 +17,15 @@ function iproc.crop(src, w1, h1, w2, h2)
|
|||
end
|
||||
return dest
|
||||
end
|
||||
-- Select the rectangular region of `src` covering rows h1+1..h2 and
-- columns w1+1..w2 by indexing `src` with range tables.
--
-- src:    object exposing :dim() and range-table indexing
--         (presumably a torch Tensor -- confirm against callers)
-- w1, h1: zero-based left/top offsets (converted to 1-based below)
-- w2, h2: inclusive right/bottom bounds in 1-based terms
--
-- Returns whatever the indexing expression yields; no explicit copy is
-- made here. NOTE(review): the name suggests the result shares storage
-- with `src` -- confirm against the tensor library's indexing semantics.
function iproc.crop_nocopy(src, w1, h1, w2, h2)
   local has_leading_dim = (src:dim() == 3)
   local row_range = { h1 + 1, h2 }
   local col_range = { w1 + 1, w2 }
   if has_leading_dim then
      -- keep the first dimension whole, crop the last two
      return src[{{}, row_range, col_range}]
   end
   -- any other rank is indexed as rows x columns
   return src[{row_range, col_range}]
end
|
||||
function iproc.byte2float(src)
|
||||
local conversion = false
|
||||
local dest = src
|
||||
|
@ -55,4 +65,5 @@ function iproc.padding(img, w1, w2, h1, h2)
|
|||
return image.warp(img, flow, "simple", false, "clamp")
|
||||
end
|
||||
|
||||
|
||||
return iproc
|
||||
|
|
|
@ -45,7 +45,7 @@ local function minibatch_adam(model, criterion,
|
|||
optim.adam(feval, parameters, config)
|
||||
|
||||
c = c + 1
|
||||
if c % 10 == 0 then
|
||||
if c % 20 == 0 then
|
||||
collectgarbage()
|
||||
end
|
||||
end
|
||||
|
|
|
@ -16,10 +16,19 @@ local function random_half(src, p)
|
|||
end
|
||||
end
|
||||
local function crop_if_large(src, max_size)
|
||||
local tries = 4
|
||||
if src:size(2) > max_size and src:size(3) > max_size then
|
||||
local rect
|
||||
for i = 1, tries do
|
||||
local yi = torch.random(0, src:size(2) - max_size)
|
||||
local xi = torch.random(0, src:size(3) - max_size)
|
||||
return iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
|
||||
rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
|
||||
-- ignore simple background
|
||||
if rect:float():std() >= 0 then
|
||||
break
|
||||
end
|
||||
end
|
||||
return rect
|
||||
else
|
||||
return src
|
||||
end
|
||||
|
@ -29,7 +38,7 @@ local function preprocess(src, crop_size, options)
|
|||
if options.random_half then
|
||||
dest = random_half(dest)
|
||||
end
|
||||
dest = crop_if_large(dest, math.max(crop_size * 4, 512))
|
||||
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
|
||||
dest = data_augmentation.flip(dest)
|
||||
if options.color_noise then
|
||||
dest = data_augmentation.color_noise(dest)
|
||||
|
@ -52,7 +61,9 @@ local function active_cropping(x, y, size, p, tries)
|
|||
return xc, yc
|
||||
else
|
||||
local samples = {}
|
||||
local sum_mse = 0
|
||||
local best_se = 0.0
|
||||
local best_xc, best_yc
|
||||
local m = torch.FloatTensor(x:size(1), size, size)
|
||||
for i = 1, tries do
|
||||
local xi = torch.random(0, y:size(3) - (size + 1))
|
||||
local yi = torch.random(0, y:size(2) - (size + 1))
|
||||
|
@ -60,17 +71,14 @@ local function active_cropping(x, y, size, p, tries)
|
|||
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
|
||||
local xcf = iproc.byte2float(xc)
|
||||
local ycf = iproc.byte2float(yc)
|
||||
local mse = (xcf - ycf):pow(2):mean()
|
||||
sum_mse = sum_mse + mse
|
||||
table.insert(samples, {xc = xc, yc = yc, mse = mse})
|
||||
local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
|
||||
if se >= best_se then
|
||||
best_xc = xcf
|
||||
best_yc = ycf
|
||||
best_se = se
|
||||
end
|
||||
if sum_mse > 0 then
|
||||
table.sort(samples,
|
||||
function (a, b)
|
||||
return a.mse > b.mse
|
||||
end)
|
||||
end
|
||||
return samples[1].xc, samples[1].yc
|
||||
return best_xc, best_yc
|
||||
end
|
||||
end
|
||||
function pairwise_transform.scale(src, scale, size, offset, n, options)
|
||||
|
@ -83,6 +91,7 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|||
"SincFast", -- 0.014095824314306
|
||||
"Jinc", -- 0.014244299255442
|
||||
}
|
||||
local unstable_region_offset = 8
|
||||
local downscale_filter = filters[torch.random(1, #filters)]
|
||||
local y = preprocess(src, size, options)
|
||||
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
|
||||
|
@ -90,6 +99,13 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|||
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
|
||||
y:size(2) * down_scale, downscale_filter),
|
||||
y:size(3), y:size(2))
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
||||
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
||||
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
||||
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
local xc, yc = active_cropping(x, y,
|
||||
|
@ -108,8 +124,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
|
|||
return batch
|
||||
end
|
||||
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
||||
local unstable_region_offset = 8
|
||||
local y = preprocess(src, size, options)
|
||||
local x = y
|
||||
|
||||
for i = 1, #quality do
|
||||
x = gm.Image(x, "RGB", "DHW")
|
||||
x:format("jpeg")
|
||||
|
@ -122,7 +140,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
|
|||
x:fromBlob(blob, len)
|
||||
x = x:toTensor("byte", "RGB", "DHW")
|
||||
end
|
||||
-- TODO: use shift_1px after compression?
|
||||
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
|
||||
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
|
||||
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
|
||||
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
|
||||
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
|
||||
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
|
||||
|
||||
local batch = {}
|
||||
for i = 1, n do
|
||||
|
@ -152,7 +175,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
|
|||
end
|
||||
elseif level == 2 then
|
||||
local r = torch.uniform()
|
||||
if torch.uniform() > 0.8 then
|
||||
if torch.uniform() > 0.9 then
|
||||
return pairwise_transform.jpeg_(src, {},
|
||||
size, offset, n, options)
|
||||
else
|
||||
|
|
|
@ -32,7 +32,7 @@ cmd:option("-scale", 2.0, 'scale')
|
|||
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
|
||||
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
|
||||
cmd:option("-crop_size", 128, 'crop size')
|
||||
cmd:option("-max_size", -1, 'crop if image size larger then this value.')
|
||||
cmd:option("-max_size", 512, 'crop if image size larger then this value.')
|
||||
cmd:option("-batch_size", 2, 'mini batch size')
|
||||
cmd:option("-epoch", 200, 'epoch')
|
||||
cmd:option("-thread", -1, 'number of CPU threads')
|
||||
|
|
|
@ -91,7 +91,7 @@ local function transformer(x, is_validation, n, offset)
|
|||
local active_cropping_tries = nil
|
||||
|
||||
if is_validation then
|
||||
active_cropping_rate = 0.0
|
||||
active_cropping_rate = 0
|
||||
active_cropping_tries = 0
|
||||
color_noise = false
|
||||
overlay = false
|
||||
|
@ -110,6 +110,7 @@ local function transformer(x, is_validation, n, offset)
|
|||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
random_half = settings.random_half,
|
||||
max_size = settings.max_size,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
rgb = (settings.color == "rgb")
|
||||
|
@ -122,10 +123,11 @@ local function transformer(x, is_validation, n, offset)
|
|||
n,
|
||||
{ color_noise = color_noise,
|
||||
overlay = overlay,
|
||||
random_half = settings.random_half,
|
||||
max_size = settings.max_size,
|
||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||
active_cropping_rate = active_cropping_rate,
|
||||
active_cropping_tries = active_cropping_tries,
|
||||
random_half = settings.random_half,
|
||||
jpeg_sampling_factors = settings.jpeg_sampling_factors,
|
||||
rgb = (settings.color == "rgb")
|
||||
})
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue