From 33e6bc888e222bd77061f3b4bae9b6517234d0d8 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 11 Sep 2016 20:56:56 +0900 Subject: [PATCH] Fix the number of threads --- train.lua | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/train.lua b/train.lua index 79afe12..c5fe873 100644 --- a/train.lua +++ b/train.lua @@ -44,8 +44,12 @@ end local g_transform_pool = nil local function transform_pool_init(has_resize, offset) + local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end g_transform_pool = threads.Threads( - torch.getnumthreads(), + nthread, function(threadid) require 'pl' local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() @@ -161,6 +165,9 @@ end local function make_validation_set(x, n, patches) local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end n = n or 4 local validation_patches = math.min(16, patches or 16) local data = {} @@ -255,10 +262,13 @@ end local function resampling(x, y, train_x) local c = 1 - local nthread = torch.getnumthreads() local shuffle = torch.randperm(#train_x) - + local nthread = torch.getnumthreads() + if (settings.thread > 0) then + nthread = settings.thread + end torch.setnumthreads(1) -- 1 + for t = 1, #train_x do local input = train_x[shuffle[t]] g_transform_pool:addjob(