diff --git a/lib/iproc.lua b/lib/iproc.lua index 1d1378b..1e39de1 100644 --- a/lib/iproc.lua +++ b/lib/iproc.lua @@ -1,5 +1,6 @@ local gm = require 'graphicsmagick' local image = require 'image' + local iproc = {} function iproc.crop_mod4(src) @@ -16,6 +17,15 @@ function iproc.crop(src, w1, h1, w2, h2) end return dest end +function iproc.crop_nocopy(src, w1, h1, w2, h2) + local dest + if src:dim() == 3 then + dest = src[{{}, { h1 + 1, h2 }, { w1 + 1, w2 }}] + else -- dim == 2 + dest = src[{{ h1 + 1, h2 }, { w1 + 1, w2 }}] + end + return dest +end function iproc.byte2float(src) local conversion = false local dest = src @@ -55,4 +65,5 @@ function iproc.padding(img, w1, w2, h1, h2) return image.warp(img, flow, "simple", false, "clamp") end + return iproc diff --git a/lib/minibatch_adam.lua b/lib/minibatch_adam.lua index af05f5a..915f1b4 100644 --- a/lib/minibatch_adam.lua +++ b/lib/minibatch_adam.lua @@ -45,7 +45,7 @@ local function minibatch_adam(model, criterion, optim.adam(feval, parameters, config) c = c + 1 - if c % 10 == 0 then + if c % 20 == 0 then collectgarbage() end end diff --git a/lib/pairwise_transform.lua b/lib/pairwise_transform.lua index 860fdb3..52bb3be 100644 --- a/lib/pairwise_transform.lua +++ b/lib/pairwise_transform.lua @@ -16,10 +16,19 @@ local function random_half(src, p) end end local function crop_if_large(src, max_size) + local tries = 4 if src:size(2) > max_size and src:size(3) > max_size then - local yi = torch.random(0, src:size(2) - max_size) - local xi = torch.random(0, src:size(3) - max_size) - return iproc.crop(src, xi, yi, xi + max_size, yi + max_size) + local rect + for i = 1, tries do + local yi = torch.random(0, src:size(2) - max_size) + local xi = torch.random(0, src:size(3) - max_size) + rect = iproc.crop(src, xi, yi, xi + max_size, yi + max_size) + -- ignore simple background + if rect:float():std() >= 0 then + break + end + end + return rect else return src end @@ -29,7 +38,7 @@ local function preprocess(src, crop_size, options) if options.random_half then dest = random_half(dest) end - dest = crop_if_large(dest, math.max(crop_size * 4, 512)) + dest = crop_if_large(dest, math.max(crop_size * 2, options.max_size)) dest = data_augmentation.flip(dest) if options.color_noise then dest = data_augmentation.color_noise(dest) @@ -52,7 +61,9 @@ local function active_cropping(x, y, size, p, tries) return xc, yc else local samples = {} - local sum_mse = 0 + local best_se = 0.0 + local best_xc, best_yc + local m = torch.FloatTensor(x:size(1), size, size) for i = 1, tries do local xi = torch.random(0, y:size(3) - (size + 1)) local yi = torch.random(0, y:size(2) - (size + 1)) @@ -60,17 +71,14 @@ local function active_cropping(x, y, size, p, tries) local yc = iproc.crop(y, xi, yi, xi + size, yi + size) local xcf = iproc.byte2float(xc) local ycf = iproc.byte2float(yc) - local mse = (xcf - ycf):pow(2):mean() - sum_mse = sum_mse + mse - table.insert(samples, {xc = xc, yc = yc, mse = mse}) + local se = m:copy(xcf):add(-1.0, ycf):pow(2):sum() + if se >= best_se then + best_xc = xcf + best_yc = ycf + best_se = se + end end - if sum_mse > 0 then - table.sort(samples, - function (a, b) - return a.mse > b.mse - end) - end - return samples[1].xc, samples[1].yc + return best_xc, best_yc end end function pairwise_transform.scale(src, scale, size, offset, n, options) @@ -83,6 +91,7 @@ function pairwise_transform.scale(src, scale, size, offset, n, options) "SincFast", -- 0.014095824314306 "Jinc", -- 0.014244299255442 } + local unstable_region_offset = 8 local downscale_filter = filters[torch.random(1, #filters)] local y = preprocess(src, size, options) assert(y:size(2) % 4 == 0 and y:size(3) % 4 == 0) @@ -90,6 +99,13 @@ function pairwise_transform.scale(src, scale, size, offset, n, options) local x = iproc.scale(iproc.scale(y, y:size(3) * down_scale, y:size(2) * down_scale, downscale_filter), y:size(3), y:size(2)) + x = iproc.crop(x, unstable_region_offset, unstable_region_offset, + x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset) + y = iproc.crop(y, unstable_region_offset, unstable_region_offset, + y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset) + assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0) + assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3)) + local batch = {} for i = 1, n do local xc, yc = active_cropping(x, y, @@ -108,8 +124,10 @@ function pairwise_transform.scale(src, scale, size, offset, n, options) return batch end function pairwise_transform.jpeg_(src, quality, size, offset, n, options) + local unstable_region_offset = 8 local y = preprocess(src, size, options) local x = y + for i = 1, #quality do x = gm.Image(x, "RGB", "DHW") x:format("jpeg") @@ -122,7 +140,12 @@ function pairwise_transform.jpeg_(src, quality, size, offset, n, options) x:fromBlob(blob, len) x = x:toTensor("byte", "RGB", "DHW") end - -- TODO: use shift_1px after compression? + x = iproc.crop(x, unstable_region_offset, unstable_region_offset, + x:size(3) - unstable_region_offset, x:size(2) - unstable_region_offset) + y = iproc.crop(y, unstable_region_offset, unstable_region_offset, + y:size(3) - unstable_region_offset, y:size(2) - unstable_region_offset) + assert(x:size(2) % 4 == 0 and x:size(3) % 4 == 0) + assert(x:size(1) == y:size(1) and x:size(2) == y:size(2) and x:size(3) == y:size(3)) local batch = {} for i = 1, n do @@ -152,7 +175,7 @@ function pairwise_transform.jpeg(src, category, level, size, offset, n, options) end elseif level == 2 then local r = torch.uniform() - if torch.uniform() > 0.8 then + if torch.uniform() > 0.9 then return pairwise_transform.jpeg_(src, {}, size, offset, n, options) else diff --git a/lib/settings.lua b/lib/settings.lua index 191a23b..3ce097c 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -32,7 +32,7 @@ cmd:option("-scale", 2.0, 'scale') cmd:option("-learning_rate", 0.00025, 'learning rate for adam') cmd:option("-random_half", 1, 'enable data augmentation using half resolution image (0|1)') cmd:option("-crop_size", 128, 'crop size') -cmd:option("-max_size", -1, 'crop if image size larger then this value.') +cmd:option("-max_size", 512, 'crop if image size larger then this value.') cmd:option("-batch_size", 2, 'mini batch size') cmd:option("-epoch", 200, 'epoch') cmd:option("-thread", -1, 'number of CPU threads') diff --git a/train.lua b/train.lua index 978dc92..b315c6e 100644 --- a/train.lua +++ b/train.lua @@ -91,7 +91,7 @@ local function transformer(x, is_validation, n, offset) local active_cropping_tries = nil if is_validation then - active_cropping_rate = 0.0 + active_cropping_rate = 0 active_cropping_tries = 0 color_noise = false overlay = false @@ -110,6 +110,7 @@ local function transformer(x, is_validation, n, offset) { color_noise = color_noise, overlay = overlay, random_half = settings.random_half, + max_size = settings.max_size, active_cropping_rate = active_cropping_rate, active_cropping_tries = active_cropping_tries, rgb = (settings.color == "rgb") @@ -122,10 +123,11 @@ local function transformer(x, is_validation, n, offset) n, { color_noise = color_noise, overlay = overlay, + random_half = settings.random_half, + max_size = settings.max_size, + jpeg_sampling_factors = settings.jpeg_sampling_factors, active_cropping_rate = active_cropping_rate, active_cropping_tries = active_cropping_tries, - random_half = settings.random_half, - jpeg_sampling_factors = settings.jpeg_sampling_factors, rgb = (settings.color == "rgb") }) end