Merge branch 'dev'
This commit is contained in:
commit
4d3d123d72
|
@ -95,6 +95,7 @@ local function load_images(list)
|
|||
if csv_meta and csv_meta.filters then
|
||||
filters = csv_meta.filters
|
||||
end
|
||||
local basename_y = path.basename(filename)
|
||||
local im, meta = image_loader.load_byte(filename)
|
||||
local skip = false
|
||||
local alpha_color = torch.random(0, 1)
|
||||
|
@ -128,7 +129,7 @@ local function load_images(list)
|
|||
yy = iproc.rgb2y(yy)
|
||||
end
|
||||
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
||||
{data = {filters = filters, has_x = true}}})
|
||||
{data = {filters = filters, has_x = true, basename = basename_y}}})
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
|
||||
end
|
||||
|
@ -144,7 +145,7 @@ local function load_images(list)
|
|||
if settings.grayscale then
|
||||
im = iproc.rgb2y(im)
|
||||
end
|
||||
table.insert(x, {compression.compress(im), {data = {filters = filters}}})
|
||||
table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
|
||||
else
|
||||
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
|
||||
end
|
||||
|
|
|
@ -79,6 +79,7 @@ cmd:option("-update_criterion", "mse", 'mse|loss')
|
|||
cmd:option("-padding", 0, 'replication padding size')
|
||||
cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
|
||||
cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
|
||||
cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
|
||||
|
||||
local function to_bool(settings, name)
|
||||
if settings[name] == 1 then
|
||||
|
@ -99,6 +100,7 @@ to_bool(settings, "pairwise_y_binary")
|
|||
to_bool(settings, "pairwise_flip")
|
||||
to_bool(settings, "padding_y_zero")
|
||||
to_bool(settings, "grayscale")
|
||||
to_bool(settings, "validation_filename_split")
|
||||
|
||||
if settings.plot then
|
||||
require 'gnuplot'
|
||||
|
|
40
train.lua
40
train.lua
|
@ -29,6 +29,45 @@ local function save_test_user(model, rgb, file)
|
|||
end
|
||||
end
|
||||
local function split_data(x, test_size)
|
||||
if settings.validation_filename_split then
|
||||
if not (x[1][2].data and x[1][2].data.basename) then
|
||||
error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
|
||||
end
|
||||
local basename_db = {}
|
||||
for i = 1, #x do
|
||||
local meta = x[i][2].data
|
||||
if basename_db[meta.basename] then
|
||||
table.insert(basename_db[meta.basename], x[i])
|
||||
else
|
||||
basename_db[meta.basename] = {x[i]}
|
||||
end
|
||||
end
|
||||
local basename_list = {}
|
||||
for k, v in pairs(basename_db) do
|
||||
table.insert(basename_list, v)
|
||||
end
|
||||
local index = torch.randperm(#basename_list)
|
||||
local train_x = {}
|
||||
local valid_x = {}
|
||||
local pos = 1
|
||||
for i = 1, #basename_list do
|
||||
if #valid_x >= test_size then
|
||||
break
|
||||
end
|
||||
local xs = basename_list[index[pos]]
|
||||
for j = 1, #xs do
|
||||
table.insert(valid_x, xs[j])
|
||||
end
|
||||
pos = pos + 1
|
||||
end
|
||||
for i = pos, #basename_list do
|
||||
local xs = basename_list[index[i]]
|
||||
for j = 1, #xs do
|
||||
table.insert(train_x, xs[j])
|
||||
end
|
||||
end
|
||||
return train_x, valid_x
|
||||
else
|
||||
local index = torch.randperm(#x)
|
||||
local train_size = #x - test_size
|
||||
local train_x = {}
|
||||
|
@ -40,6 +79,7 @@ local function split_data(x, test_size)
|
|||
valid_x[i] = x[index[train_size + i]]
|
||||
end
|
||||
return train_x, valid_x
|
||||
end
|
||||
end
|
||||
|
||||
local g_transform_pool = nil
|
||||
|
|
Loading…
Reference in a new issue