Fix the number of threads
This commit is contained in:
parent
c2e4bb4380
commit
33e6bc888e
1 changed files with 13 additions and 3 deletions
16
train.lua
16
train.lua
|
@ -44,8 +44,12 @@ end
|
|||
|
||||
local g_transform_pool = nil
|
||||
local function transform_pool_init(has_resize, offset)
|
||||
local nthread = torch.getnumthreads()
|
||||
if (settings.thread > 0) then
|
||||
nthread = settings.thread
|
||||
end
|
||||
g_transform_pool = threads.Threads(
|
||||
torch.getnumthreads(),
|
||||
nthread,
|
||||
function(threadid)
|
||||
require 'pl'
|
||||
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
|
||||
|
@ -161,6 +165,9 @@ end
|
|||
|
||||
local function make_validation_set(x, n, patches)
|
||||
local nthread = torch.getnumthreads()
|
||||
if (settings.thread > 0) then
|
||||
nthread = settings.thread
|
||||
end
|
||||
n = n or 4
|
||||
local validation_patches = math.min(16, patches or 16)
|
||||
local data = {}
|
||||
|
@ -255,10 +262,13 @@ end
|
|||
|
||||
local function resampling(x, y, train_x)
|
||||
local c = 1
|
||||
local nthread = torch.getnumthreads()
|
||||
local shuffle = torch.randperm(#train_x)
|
||||
|
||||
local nthread = torch.getnumthreads()
|
||||
if (settings.thread > 0) then
|
||||
nthread = settings.thread
|
||||
end
|
||||
torch.setnumthreads(1) -- 1
|
||||
|
||||
for t = 1, #train_x do
|
||||
local input = train_x[shuffle[t]]
|
||||
g_transform_pool:addjob(
|
||||
|
|
Loading…
Reference in a new issue