From bb0fc3a1d3071a0019444f92ce40ad8fa0172c85 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 15 Apr 2017 16:29:38 +0900 Subject: [PATCH] Add -validation_filename_split option --- convert_data.lua | 5 ++-- lib/settings.lua | 2 ++ train.lua | 60 ++++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 55 insertions(+), 12 deletions(-) diff --git a/convert_data.lua b/convert_data.lua index 11d2f62..5375243 100644 --- a/convert_data.lua +++ b/convert_data.lua @@ -95,6 +95,7 @@ local function load_images(list) if csv_meta and csv_meta.filters then filters = csv_meta.filters end + local basename_y = path.basename(filename) local im, meta = image_loader.load_byte(filename) local skip = false local alpha_color = torch.random(0, 1) @@ -128,7 +129,7 @@ local function load_images(list) yy = iproc.rgb2y(yy) end table.insert(x, {{y = compression.compress(yy), x = compression.compress(xx)}, - {data = {filters = filters, has_x = true}}}) + {data = {filters = filters, has_x = true, basename = basename_y}}}) else io.stderr:write(string.format("\n%s: skip: load error.\n", csv_meta.x)) end @@ -144,7 +145,7 @@ local function load_images(list) if settings.grayscale then im = iproc.rgb2y(im) end - table.insert(x, {compression.compress(im), {data = {filters = filters}}}) + table.insert(x, {compression.compress(im), {data = {filters = filters, basename = basename_y}}}) else io.stderr:write(string.format("\n%s: skip: image is too small (%d > size).\n", filename, settings.crop_size * scale + MARGIN)) end diff --git a/lib/settings.lua b/lib/settings.lua index 4248353..88a913d 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -79,6 +79,7 @@ cmd:option("-update_criterion", "mse", 'mse|loss') cmd:option("-padding", 0, 'replication padding size') cmd:option("-padding_y_zero", 0, 'zero padding y for segmentation (0|1)') cmd:option("-grayscale", 0, 'grayscale x&y (0|1)') +cmd:option("-validation_filename_split", 0, 'make validation-set based on filename(basename)') local function to_bool(settings, name) if settings[name] == 1 then @@ -99,6 +100,7 @@ to_bool(settings, "pairwise_y_binary") to_bool(settings, "pairwise_flip") to_bool(settings, "padding_y_zero") to_bool(settings, "grayscale") +to_bool(settings, "validation_filename_split") if settings.plot then require 'gnuplot' diff --git a/train.lua b/train.lua index aa8137d..3ba8e48 100644 --- a/train.lua +++ b/train.lua @@ -29,17 +29,57 @@ local function save_test_user(model, rgb, file) end end local function split_data(x, test_size) - local index = torch.randperm(#x) - local train_size = #x - test_size - local train_x = {} - local valid_x = {} - for i = 1, train_size do - train_x[i] = x[index[i]] + if settings.validation_filename_split then + if not (x[1][2].data and x[1][2].data.basename) then + error("`images.t` does not have basename info. You need to re-run `convert_data.lua`.") + end + local basename_db = {} + for i = 1, #x do + local meta = x[i][2].data + if basename_db[meta.basename] then + table.insert(basename_db[meta.basename], x[i]) + else + basename_db[meta.basename] = {x[i]} + end + end + local basename_list = {} + for k, v in pairs(basename_db) do + table.insert(basename_list, v) + end + local index = torch.randperm(#basename_list) + local train_x = {} + local valid_x = {} + local pos = 1 + for i = 1, #basename_list do + if #valid_x >= test_size then + break + end + local xs = basename_list[index[pos]] + for j = 1, #xs do + table.insert(valid_x, xs[j]) + end + pos = pos + 1 + end + for i = pos, #basename_list do + local xs = basename_list[index[i]] + for j = 1, #xs do + table.insert(train_x, xs[j]) + end + end + return train_x, valid_x + else + local index = torch.randperm(#x) + local train_size = #x - test_size + local train_x = {} + local valid_x = {} + for i = 1, train_size do + train_x[i] = x[index[i]] + end + for i = 1, test_size do + valid_x[i] = x[index[train_size + i]] + end + return train_x, valid_x end - for i = 1, test_size do - valid_x[i] = x[index[train_size + i]] - end - return train_x, valid_x end local g_transform_pool = nil