diff --git a/train.lua b/train.lua index 79afe12..c5fe873 100644 --- a/train.lua +++ b/train.lua @@ -44,8 +44,12 @@ end local g_transform_pool = nil local function transform_pool_init(has_resize, offset) + local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end g_transform_pool = threads.Threads( - torch.getnumthreads(), + nthread, function(threadid) require 'pl' local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() @@ -161,6 +165,9 @@ end local function make_validation_set(x, n, patches) local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end n = n or 4 local validation_patches = math.min(16, patches or 16) local data = {} @@ -255,10 +262,13 @@ end local function resampling(x, y, train_x) local c = 1 - local nthread = torch.getnumthreads() local shuffle = torch.randperm(#train_x) - + local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end torch.setnumthreads(1) -- 1 + for t = 1, #train_x do local input = train_x[shuffle[t]] g_transform_pool:addjob(