1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00

Merge branch 'dev'

This commit is contained in:
nagadomi 2017-04-15 16:30:34 +09:00
commit 4d3d123d72
3 changed files with 55 additions and 12 deletions

View file

@ -95,6 +95,7 @@ local function load_images(list)
if csv_meta and csv_meta.filters then
filters = csv_meta.filters
end
local basename_y = path.basename(filename)
local im, meta = image_loader.load_byte(filename)
local skip = false
local alpha_color = torch.random(0, 1)
@ -128,7 +129,7 @@ local function load_images(list)
yy = iproc.rgb2y(yy)
end
table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)},
{data = {filters = filters, has_x = true}}})
{data = {filters = filters, has_x = true, basename = basename_y}}})
else
io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x))
end
@ -144,7 +145,7 @@ local function load_images(list)
if settings.grayscale then
im = iproc.rgb2y(im)
end
table.insert(x, {compression.compress(im), {data = {filters = filters}}})
table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}})
else
io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN))
end

View file

@ -79,6 +79,7 @@ cmd:option("-update_criterion", "mse", 'mse|loss')
cmd:option("-padding", 0, 'replication padding size')
cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)')
cmd:option("-grayscale", 0, 'grayscale x&y (0|1)')
cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)')
local function to_bool(settings, name)
if settings[name] == 1 then
@ -99,6 +100,7 @@ to_bool(settings, "pairwise_y_binary")
to_bool(settings, "pairwise_flip")
to_bool(settings, "padding_y_zero")
to_bool(settings, "grayscale")
to_bool(settings, "validation_filename_split")
if settings.plot then
require 'gnuplot'

View file

@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file)
end
end
local function split_data(x, test_size)
local index = torch.randperm(#x)
local train_size = #x - test_size
local train_x = {}
local valid_x = {}
for i = 1, train_size do
train_x[i] = x[index[i]]
if settings.validation_filename_split then
if not (x[1][2].data and x[1][2].data.basename) then
error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.")
end
local basename_db = {}
for i = 1, #x do
local meta = x[i][2].data
if basename_db[meta.basename] then
table.insert(basename_db[meta.basename], x[i])
else
basename_db[meta.basename] = {x[i]}
end
end
local basename_list = {}
for k, v in pairs(basename_db) do
table.insert(basename_list, v)
end
local index = torch.randperm(#basename_list)
local train_x = {}
local valid_x = {}
local pos = 1
for i = 1, #basename_list do
if #valid_x >= test_size then
break
end
local xs = basename_list[index[pos]]
for j = 1, #xs do
table.insert(valid_x, xs[j])
end
pos = pos + 1
end
for i = pos, #basename_list do
local xs = basename_list[index[i]]
for j = 1, #xs do
table.insert(train_x, xs[j])
end
end
return train_x, valid_x
else
local index = torch.randperm(#x)
local train_size = #x - test_size
local train_x = {}
local valid_x = {}
for i = 1, train_size do
train_x[i] = x[index[i]]
end
for i = 1, test_size do
valid_x[i] = x[index[train_size + i]]
end
return train_x, valid_x
end
for i = 1, test_size do
valid_x[i] = x[index[train_size + i]]
end
return train_x, valid_x
end
local g_transform_pool = nil