1
0
Fork 0
mirror of synced 2024-06-14 00:44:32 +12:00
This commit is contained in:
nagadomi 2015-11-03 06:10:44 +09:00
parent 425898a3aa
commit b35a9ae7d7
5 changed files with 58 additions and 22 deletions

View file

@ -1,5 +1,6 @@
local gm = require 'graphicsmagick'
local image = require 'image'
local iproc = {}
function iproc.crop_mod4(src)
@ -16,6 +17,15 @@ function iproc.crop(src, w1, h1, w2, h2)
end
return dest
end
--- Select a rectangular region of `src` by range-indexing it directly.
-- The indexing result is returned as-is; this function performs no explicit
-- copy of its own (presumably Torch hands back a view sharing storage with
-- `src` -- confirm before mutating the result).
-- @param src tensor-like object; a 3-dim input keeps its leading axis whole
--            (presumably channels -- confirm), any other rank is indexed as
--            2-dim (rows, columns)
-- @param w1 zero-based left edge; the first selected column is w1 + 1
-- @param h1 zero-based top edge; the first selected row is h1 + 1
-- @param w2 last selected column (inclusive, one-based)
-- @param h2 last selected row (inclusive, one-based)
-- @return the indexed sub-region, (h2 - h1) rows by (w2 - w1) columns
function iproc.crop_nocopy(src, w1, h1, w2, h2)
   local three_dims = (src:dim() == 3)
   -- Convert the zero-based start offsets to one-based inclusive ranges.
   local row_range = { h1 + 1, h2 }
   local col_range = { w1 + 1, w2 }
   if three_dims then
      -- The empty range `{}` selects the whole leading axis.
      return src[{ {}, row_range, col_range }]
   end
   -- dim == 2
   return src[{ row_range, col_range }]
end
function iproc.byte2float(src)
local conversion = false
local dest = src
@ -55,4 +65,5 @@ function iproc.padding(img, w1, w2, h1, h2)
return image.warp(img, flow, "simple", false, "clamp")
end
return iproc

View file

@ -45,7 +45,7 @@ local function minibatch_adam(model, criterion,
optim.adam(feval, parameters, config)
c = c + 1
if c % 10 == 0 then
if c % 20 == 0 then
collectgarbage()
end
end

View file

@ -16,10 +16,19 @@ local function random_half(src, p)
end
end
local function crop_if_large(src, max_size)
local tries = 4
if src:size(2) > max_size and src:size(3) > max_size then
local yi = torch.random(0, src:size(2) - max_size)
local xi = torch.random(0, src:size(3) - max_size)
return iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
local rect
for i = 1, tries do
local yi = torch.random(0, src:size(2) - max_size)
local xi = torch.random(0, src:size(3) - max_size)
rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size)
-- ignore simple background
if rect:float():std() >= 0 then
break
end
end
return rect
else
return src
end
@ -29,7 +38,7 @@ local function preprocess(src, crop_size, options)
if options.random_half then
dest = random_half(dest)
end
dest = crop_if_large(dest, math.max(crop_size * 4, 512))
dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size))
dest = data_augmentation.flip(dest)
if options.color_noise then
dest = data_augmentation.color_noise(dest)
@ -52,7 +61,9 @@ local function active_cropping(x, y, size, p, tries)
return xc, yc
else
local samples = {}
local sum_mse = 0
local best_se = 0.0
local best_xc, best_yc
local m = torch.FloatTensor(x:size(1), size, size)
for i = 1, tries do
local xi = torch.random(0, y:size(3) - (size + 1))
local yi = torch.random(0, y:size(2) - (size + 1))
@ -60,17 +71,14 @@ local function active_cropping(x, y, size, p, tries)
local yc = iproc.crop(y, xi, yi, xi + size, yi + size)
local xcf = iproc.byte2float(xc)
local ycf = iproc.byte2float(yc)
local mse = (xcf - ycf):pow(2):mean()
sum_mse = sum_mse + mse
table.insert(samples, {xc = xc, yc = yc, mse = mse})
local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum()
if se >= best_se then
best_xc = xcf
best_yc = ycf
best_se = se
end
end
if sum_mse > 0 then
table.sort(samples,
function (a, b)
return a.mse > b.mse
end)
end
return samples[1].xc, samples[1].yc
return best_xc, best_yc
end
end
function pairwise_transform.scale(src, scale, size, offset, n, options)
@ -83,6 +91,7 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
"SincFast", -- 0.014095824314306
"Jinc", -- 0.014244299255442
}
local unstable_region_offset = 8
local downscale_filter = filters[torch.random(1, #filters)]
local y = preprocess(src, size, options)
assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0)
@ -90,6 +99,13 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale,
y:size(2) * down_scale, downscale_filter),
y:size(3), y:size(2))
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
local batch = {}
for i = 1, n do
local xc, yc = active_cropping(x, y,
@ -108,8 +124,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options)
return batch
end
function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
local unstable_region_offset = 8
local y = preprocess(src, size, options)
local x = y
for i = 1, #quality do
x = gm.Image(x, "RGB", "DHW")
x:format("jpeg")
@ -122,7 +140,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options)
x:fromBlob(blob, len)
x = x:toTensor("byte", "RGB", "DHW")
end
-- TODO: use shift_1px after compression?
x = iproc.crop(x, unstable_region_offset, unstable_region_offset,
x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset)
y = iproc.crop(y, unstable_region_offset, unstable_region_offset,
y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset)
assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0)
assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3))
local batch = {}
for i = 1, n do
@ -152,7 +175,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options)
end
elseif level == 2 then
local r = torch.uniform()
if torch.uniform() > 0.8 then
if torch.uniform() > 0.9 then
return pairwise_transform.jpeg_(src, {},
size, offset, n, options)
else

View file

@ -32,7 +32,7 @@ cmd:option("-scale", 2.0, 'scale')
cmd:option("-learning_rate", 0.00025, 'learning rate for adam')
cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)')
cmd:option("-crop_size", 128, 'crop size')
cmd:option("-max_size", -1, 'crop if image size larger then this value.')
cmd:option("-max_size", 512, 'crop if image size larger then this value.')
cmd:option("-batch_size", 2, 'mini batch size')
cmd:option("-epoch", 200, 'epoch')
cmd:option("-thread", -1, 'number of CPU threads')

View file

@ -91,7 +91,7 @@ local function transformer(x, is_validation, n, offset)
local active_cropping_tries = nil
if is_validation then
active_cropping_rate = 0.0
active_cropping_rate = 0
active_cropping_tries = 0
color_noise = false
overlay = false
@ -110,6 +110,7 @@ local function transformer(x, is_validation, n, offset)
{ color_noise = color_noise,
overlay = overlay,
random_half = settings.random_half,
max_size = settings.max_size,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
rgb = (settings.color == "rgb")
@ -122,10 +123,11 @@ local function transformer(x, is_validation, n, offset)
n,
{ color_noise = color_noise,
overlay = overlay,
random_half = settings.random_half,
max_size = settings.max_size,
jpeg_sampling_factors = settings.jpeg_sampling_factors,
active_cropping_rate = active_cropping_rate,
active_cropping_tries = active_cropping_tries,
random_half = settings.random_half,
jpeg_sampling_factors = settings.jpeg_sampling_factors,
rgb = (settings.color == "rgb")
})
end