diff --git a/train.lua b/train.lua index d2f652c..5c28f0d 100644 --- a/train.lua +++ b/train.lua @@ -43,11 +43,15 @@ local function split_data(x, test_size) end local g_transform_pool = nil +local g_mutex = nil +local g_mutex_id = nil local function transform_pool_init(has_resize, offset) local nthread = torch.getnumthreads() if (settings.thread > 0) then nthread = settings.thread end + g_mutex = threads.Mutex() + g_mutex_id = g_mutex:id() g_transform_pool = threads.Threads( nthread, function(threadid) @@ -56,10 +60,13 @@ local function transform_pool_init(has_resize, offset) package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path require 'nn' require 'cunn' + local threads = require 'threads' + local compression = require 'compression' local pairwise_transform = require 'pairwise_transform' function transformer(x, is_validation, n) + local mutex = threads.Mutex(g_mutex_id) local meta = {data = {}} local y = nil if type(x) == "table" and type(x[2]) == "table" then @@ -92,6 +99,7 @@ local function transform_pool_init(has_resize, offset) end if settings.method == "scale" then local conf = tablex.update({ + mutex = mutex, downsampling_filters = settings.downsampling_filters, random_half_rate = settings.random_half_rate, random_color_noise_rate = random_color_noise_rate, @@ -114,6 +122,7 @@ local function transform_pool_init(has_resize, offset) n, conf) elseif settings.method == "noise" then local conf = tablex.update({ + mutex = mutex, random_half_rate = settings.random_half_rate, random_color_noise_rate = random_color_noise_rate, random_overlay_rate = random_overlay_rate, @@ -135,6 +144,7 @@ local function transform_pool_init(has_resize, offset) n, conf) elseif settings.method == "noise_scale" then local conf = tablex.update({ + mutex = mutex, downsampling_filters = settings.downsampling_filters, random_half_rate = settings.random_half_rate, random_color_noise_rate = random_color_noise_rate,