Add -validation_filename_split option
This commit is contained in:
parent
88e3322296
commit
bb0fc3a1d3
|
@ -95,6 +95,7 @@ local function load_images(list)
|
||||||
if csv_meta and csv_meta.filters then
|
if csv_meta and csv_meta.filters then
|
||||||
filters = csv_meta.filters
|
filters = csv_meta.filters
|
||||||
end
|
end
|
||||||
|
local basename_y = path.basename(filename)
|
||||||
local im, meta = image_loader.load_byte(filename)
|
local im, meta = image_loader.load_byte(filename)
|
||||||
local skip = false
|
local skip = false
|
||||||
local alpha_color = torch.random(0, 1)
|
local alpha_color = torch.random(0, 1)
|
||||||
|
@ -128,7 +129,7 @@ local function load_images(list)
|
||||||
yy = iproc.rgb2y(yy)
|
yy = iproc.rgb2y(yy)
|
||||||
end
|
end
|
||||||
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
|
||||||
{data = {filters = filters, has_x = true}}})
|
{data = {filters = filters, has_x = true, basename = basename_y}}})
|
||||||
else
|
else
|
||||||
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
|
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
|
||||||
end
|
end
|
||||||
|
@ -144,7 +145,7 @@ local function load_images(list)
|
||||||
if settings.grayscale then
|
if settings.grayscale then
|
||||||
im = iproc.rgb2y(im)
|
im = iproc.rgb2y(im)
|
||||||
end
|
end
|
||||||
table.insert(x, {compression.compress(im), {data = {filters = filters}}})
|
table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
|
||||||
else
|
else
|
||||||
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
|
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
|
||||||
end
|
end
|
||||||
|
|
|
@ -79,6 +79,7 @@ cmd:option("-update_criterion", "mse", 'mse|loss')
|
||||||
cmd:option("-padding", 0, 'replication padding size')
|
cmd:option("-padding", 0, 'replication padding size')
|
||||||
cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
|
cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
|
||||||
cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
|
cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
|
||||||
|
cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
|
||||||
|
|
||||||
local function to_bool(settings, name)
|
local function to_bool(settings, name)
|
||||||
if settings[name] == 1 then
|
if settings[name] == 1 then
|
||||||
|
@ -99,6 +100,7 @@ to_bool(settings, "pairwise_y_binary")
|
||||||
to_bool(settings, "pairwise_flip")
|
to_bool(settings, "pairwise_flip")
|
||||||
to_bool(settings, "padding_y_zero")
|
to_bool(settings, "padding_y_zero")
|
||||||
to_bool(settings, "grayscale")
|
to_bool(settings, "grayscale")
|
||||||
|
to_bool(settings, "validation_filename_split")
|
||||||
|
|
||||||
if settings.plot then
|
if settings.plot then
|
||||||
require 'gnuplot'
|
require 'gnuplot'
|
||||||
|
|
60
train.lua
60
train.lua
|
@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
-- Split the dataset `x` into a training set and a validation set.
-- x: array of samples; each sample is {payload, {data = {..., basename = ...}}}
--    (basename is only present when convert_data.lua recorded it).
-- test_size: number of samples reserved for validation. When splitting by
--            filename this is a lower bound, since whole basename groups are
--            moved as a unit.
-- Returns: train_x, valid_x (plain Lua arrays of samples).
-- Reads the file-global `settings.validation_filename_split` flag.
local function split_data(x, test_size)
   if settings.validation_filename_split then
      -- Group-aware split: all crops that share a source filename (basename)
      -- must land on the same side, otherwise near-duplicate images leak
      -- from the training set into the validation set.
      -- Guard #x == 0 first: indexing x[1] on an empty dataset would raise a
      -- misleading nil-index error instead of the diagnostic below.
      if #x == 0 or not (x[1][2].data and x[1][2].data.basename) then
         error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
      end
      -- Bucket samples by basename.
      local basename_db = {}
      for i = 1, #x do
         local meta = x[i][2].data
         if basename_db[meta.basename] then
            table.insert(basename_db[meta.basename], x[i])
         else
            basename_db[meta.basename] = {x[i]}
         end
      end
      local basename_list = {}
      for _, v in pairs(basename_db) do
         table.insert(basename_list, v)
      end
      -- Shuffle the groups, then move whole groups into the validation set
      -- until it holds at least `test_size` samples.
      local index = torch.randperm(#basename_list)
      local train_x = {}
      local valid_x = {}
      local pos = 1
      while pos <= #basename_list and #valid_x < test_size do
         local xs = basename_list[index[pos]]
         for j = 1, #xs do
            table.insert(valid_x, xs[j])
         end
         pos = pos + 1
      end
      -- The remaining groups all go to the training set.
      for i = pos, #basename_list do
         local xs = basename_list[index[i]]
         for j = 1, #xs do
            table.insert(train_x, xs[j])
         end
      end
      return train_x, valid_x
   else
      -- Plain per-sample random split: first (#x - test_size) shuffled
      -- samples train, the remaining test_size samples validate.
      local index = torch.randperm(#x)
      local train_size = #x - test_size
      local train_x = {}
      local valid_x = {}
      for i = 1, train_size do
         train_x[i] = x[index[i]]
      end
      for i = 1, test_size do
         valid_x[i] = x[index[train_size + i]]
      end
      return train_x, valid_x
   end
end
|
||||||
|
|
||||||
local g_transform_pool = nil
|
local g_transform_pool = nil
|
||||||
|
|
Loading…
Reference in a new issue