Fix the number of threads
This commit is contained in:
parent
c2e4bb4380
commit
33e6bc888e
1 changed files with 13 additions and 3 deletions
16
train.lua
16
train.lua
|
@ -44,8 +44,12 @@ end
|
||||||
|
|
||||||
local g_transform_pool = nil
|
local g_transform_pool = nil
|
||||||
local function transform_pool_init(has_resize, offset)
|
local function transform_pool_init(has_resize, offset)
|
||||||
|
local nthread = torch.getnumthreads()
|
||||||
|
if (settings.thread > 0) then
|
||||||
|
nthread = settings.thread
|
||||||
|
end
|
||||||
g_transform_pool = threads.Threads(
|
g_transform_pool = threads.Threads(
|
||||||
torch.getnumthreads(),
|
nthread,
|
||||||
function(threadid)
|
function(threadid)
|
||||||
require 'pl'
|
require 'pl'
|
||||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||||
|
@ -161,6 +165,9 @@ end
|
||||||
|
|
||||||
local function make_validation_set(x, n, patches)
|
local function make_validation_set(x, n, patches)
|
||||||
local nthread = torch.getnumthreads()
|
local nthread = torch.getnumthreads()
|
||||||
|
if (settings.thread > 0) then
|
||||||
|
nthread = settings.thread
|
||||||
|
end
|
||||||
n = n or 4
|
n = n or 4
|
||||||
local validation_patches = math.min(16, patches or 16)
|
local validation_patches = math.min(16, patches or 16)
|
||||||
local data = {}
|
local data = {}
|
||||||
|
@ -255,10 +262,13 @@ end
|
||||||
|
|
||||||
local function resampling(x, y, train_x)
|
local function resampling(x, y, train_x)
|
||||||
local c = 1
|
local c = 1
|
||||||
local nthread = torch.getnumthreads()
|
|
||||||
local shuffle = torch.randperm(#train_x)
|
local shuffle = torch.randperm(#train_x)
|
||||||
|
local nthread = torch.getnumthreads()
|
||||||
|
if (settings.thread > 0) then
|
||||||
|
nthread = settings.thread
|
||||||
|
end
|
||||||
torch.setnumthreads(1) -- 1
|
torch.setnumthreads(1) -- 1
|
||||||
|
|
||||||
for t = 1, #train_x do
|
for t = 1, #train_x do
|
||||||
local input = train_x[shuffle[t]]
|
local input = train_x[shuffle[t]]
|
||||||
g_transform_pool:addjob(
|
g_transform_pool:addjob(
|
||||||
|
|
Loading…
Reference in a new issue