Add support for padding in convert_data.lua
This commit is contained in:
parent
202453d6c5
commit
71a34393b8
|
@ -63,7 +63,24 @@ local function crop_if_large_pair(x, y, max_size)
|
||||||
return x, y
|
return x, y
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
local function padding_x(x, pad)
|
||||||
|
if pad > 0 then
|
||||||
|
x = iproc.padding(x, pad, pad, pad, pad)
|
||||||
|
end
|
||||||
|
return x
|
||||||
|
end
|
||||||
|
local function padding_xy(x, y, pad, y_zero)
|
||||||
|
local scale = y:size(2) / x:size(2)
|
||||||
|
if pad > 0 then
|
||||||
|
x = iproc.padding(x, pad, pad, pad, pad)
|
||||||
|
if y_zero then
|
||||||
|
y = iproc.zero_padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
|
||||||
|
else
|
||||||
|
y = iproc.padding(y, pad * scale, pad * scale, pad * scale, pad * scale)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return x, y
|
||||||
|
end
|
||||||
local function load_images(list)
|
local function load_images(list)
|
||||||
local MARGIN = 32
|
local MARGIN = 32
|
||||||
local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
|
local csv = csvigo.load({path = list, verbose = false, mode = "raw"})
|
||||||
|
@ -105,6 +122,7 @@ local function load_images(list)
|
||||||
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
|
xx = alpha_util.fill(xx, meta2.alpha, alpha_color)
|
||||||
end
|
end
|
||||||
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
|
xx, yy = crop_if_large_pair(xx, yy, settings.max_training_image_size)
|
||||||
|
xx, yy = padding_xy(xx, yy, settings.padding, settings.padding_y_zero)
|
||||||
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
||||||
{data = {filters = filters, has_x = true}}})
|
{data = {filters = filters, has_x = true}}})
|
||||||
else
|
else
|
||||||
|
@ -113,6 +131,7 @@ local function load_images(list)
|
||||||
else
|
else
|
||||||
im = crop_if_large(im, settings.max_training_image_size)
|
im = crop_if_large(im, settings.max_training_image_size)
|
||||||
im = iproc.crop_mod4(im)
|
im = iproc.crop_mod4(im)
|
||||||
|
im = padding_x(im, settings.padding)
|
||||||
local scale = 1.0
|
local scale = 1.0
|
||||||
if settings.random_half_rate > 0.0 then
|
if settings.random_half_rate > 0.0 then
|
||||||
scale = 2.0
|
scale = 2.0
|
||||||
|
|
|
@ -80,6 +80,8 @@ function iproc.scale_with_gamma22(src, width, height, filter, blur)
|
||||||
return dest
|
return dest
|
||||||
end
|
end
|
||||||
function iproc.padding(img, w1, w2, h1, h2)
|
function iproc.padding(img, w1, w2, h1, h2)
|
||||||
|
local conversion
|
||||||
|
img, conversion = iproc.byte2float(img)
|
||||||
image = image or require 'image'
|
image = image or require 'image'
|
||||||
local dst_height = img:size(2) + h1 + h2
|
local dst_height = img:size(2) + h1 + h2
|
||||||
local dst_width = img:size(3) + w1 + w2
|
local dst_width = img:size(3) + w1 + w2
|
||||||
|
@ -88,9 +90,15 @@ function iproc.padding(img, w1, w2, h1, h2)
|
||||||
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
|
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
|
||||||
flow[1]:add(-h1)
|
flow[1]:add(-h1)
|
||||||
flow[2]:add(-w1)
|
flow[2]:add(-w1)
|
||||||
return image.warp(img, flow, "simple", false, "clamp")
|
local dest = image.warp(img, flow, "simple", false, "clamp")
|
||||||
|
if conversion then
|
||||||
|
dest = iproc.float2byte(dest)
|
||||||
|
end
|
||||||
|
return dest
|
||||||
end
|
end
|
||||||
function iproc.zero_padding(img, w1, w2, h1, h2)
|
function iproc.zero_padding(img, w1, w2, h1, h2)
|
||||||
|
local conversion
|
||||||
|
img, conversion = iproc.byte2float(img)
|
||||||
image = image or require 'image'
|
image = image or require 'image'
|
||||||
local dst_height = img:size(2) + h1 + h2
|
local dst_height = img:size(2) + h1 + h2
|
||||||
local dst_width = img:size(3) + w1 + w2
|
local dst_width = img:size(3) + w1 + w2
|
||||||
|
@ -99,7 +107,11 @@ function iproc.zero_padding(img, w1, w2, h1, h2)
|
||||||
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
|
flow[2] = torch.ger(torch.ones(dst_height), torch.linspace(0, dst_width - 1, dst_width))
|
||||||
flow[1]:add(-h1)
|
flow[1]:add(-h1)
|
||||||
flow[2]:add(-w1)
|
flow[2]:add(-w1)
|
||||||
return image.warp(img, flow, "simple", false, "pad", 0)
|
local dest = image.warp(img, flow, "simple", false, "pad", 0)
|
||||||
|
if conversion then
|
||||||
|
dest = iproc.float2byte(dest)
|
||||||
|
end
|
||||||
|
return dest
|
||||||
end
|
end
|
||||||
function iproc.white_noise(src, std, rgb_weights, gamma)
|
function iproc.white_noise(src, std, rgb_weights, gamma)
|
||||||
gamma = gamma or 0.454545
|
gamma = gamma or 0.454545
|
||||||
|
|
|
@ -77,6 +77,8 @@ cmd:option("-name", "user", 'model name for user method')
|
||||||
cmd:option("-gpu", 1, 'Device ID')
|
cmd:option("-gpu", 1, 'Device ID')
|
||||||
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
|
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
|
||||||
cmd:option("-update_criterion", "mse", 'mse|loss')
|
cmd:option("-update_criterion", "mse", 'mse|loss')
|
||||||
|
cmd:option("-padding", 0, 'replication padding size')
|
||||||
|
cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
if settings[name] == 1 then
|
if settings[name] == 1 then
|
||||||
|
@ -95,6 +97,7 @@ to_bool(settings, "save_history")
|
||||||
to_bool(settings, "use_transparent_png")
|
to_bool(settings, "use_transparent_png")
|
||||||
to_bool(settings, "pairwise_y_binary")
|
to_bool(settings, "pairwise_y_binary")
|
||||||
to_bool(settings, "pairwise_flip")
|
to_bool(settings, "pairwise_flip")
|
||||||
|
to_bool(settings, "padding_y_zero")
|
||||||
|
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
require 'gnuplot'
|
require 'gnuplot'
|
||||||
|
|
Loading…
Reference in a new issue