mutex
This commit is contained in:
parent
5a3d012f4e
commit
a14e6acec3
1 changed files with 10 additions and 0 deletions
10
train.lua
10
train.lua
|
@ -43,11 +43,15 @@ local function split_data(x, test_size)
|
||||||
end
|
end
|
||||||
|
|
||||||
local g_transform_pool = nil
|
local g_transform_pool = nil
|
||||||
|
local g_mutex = nil
|
||||||
|
local g_mutex_id = nil
|
||||||
local function transform_pool_init(has_resize, offset)
|
local function transform_pool_init(has_resize, offset)
|
||||||
local nthread = torch.getnumthreads()
|
local nthread = torch.getnumthreads()
|
||||||
if (settings.thread > 0) then
|
if (settings.thread > 0) then
|
||||||
nthread = settings.thread
|
nthread = settings.thread
|
||||||
end
|
end
|
||||||
|
g_mutex = threads.Mutex()
|
||||||
|
g_mutex_id = g_mutex:id()
|
||||||
g_transform_pool = threads.Threads(
|
g_transform_pool = threads.Threads(
|
||||||
nthread,
|
nthread,
|
||||||
function(threadid)
|
function(threadid)
|
||||||
|
@ -56,10 +60,13 @@ local function transform_pool_init(has_resize, offset)
|
||||||
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
|
||||||
require 'nn'
|
require 'nn'
|
||||||
require 'cunn'
|
require 'cunn'
|
||||||
|
local threads = require 'threads'
|
||||||
|
|
||||||
local compression = require 'compression'
|
local compression = require 'compression'
|
||||||
local pairwise_transform = require 'pairwise_transform'
|
local pairwise_transform = require 'pairwise_transform'
|
||||||
|
|
||||||
function transformer(x, is_validation, n)
|
function transformer(x, is_validation, n)
|
||||||
|
local mutex = threads.Mutex(g_mutex_id)
|
||||||
local meta = {data = {}}
|
local meta = {data = {}}
|
||||||
local y = nil
|
local y = nil
|
||||||
if type(x) == "table" and type(x[2]) == "table" then
|
if type(x) == "table" and type(x[2]) == "table" then
|
||||||
|
@ -92,6 +99,7 @@ local function transform_pool_init(has_resize, offset)
|
||||||
end
|
end
|
||||||
if settings.method == "scale" then
|
if settings.method == "scale" then
|
||||||
local conf = tablex.update({
|
local conf = tablex.update({
|
||||||
|
mutex = mutex,
|
||||||
downsampling_filters = settings.downsampling_filters,
|
downsampling_filters = settings.downsampling_filters,
|
||||||
random_half_rate = settings.random_half_rate,
|
random_half_rate = settings.random_half_rate,
|
||||||
random_color_noise_rate = random_color_noise_rate,
|
random_color_noise_rate = random_color_noise_rate,
|
||||||
|
@ -114,6 +122,7 @@ local function transform_pool_init(has_resize, offset)
|
||||||
n, conf)
|
n, conf)
|
||||||
elseif settings.method == "noise" then
|
elseif settings.method == "noise" then
|
||||||
local conf = tablex.update({
|
local conf = tablex.update({
|
||||||
|
mutex = mutex,
|
||||||
random_half_rate = settings.random_half_rate,
|
random_half_rate = settings.random_half_rate,
|
||||||
random_color_noise_rate = random_color_noise_rate,
|
random_color_noise_rate = random_color_noise_rate,
|
||||||
random_overlay_rate = random_overlay_rate,
|
random_overlay_rate = random_overlay_rate,
|
||||||
|
@ -135,6 +144,7 @@ local function transform_pool_init(has_resize, offset)
|
||||||
n, conf)
|
n, conf)
|
||||||
elseif settings.method == "noise_scale" then
|
elseif settings.method == "noise_scale" then
|
||||||
local conf = tablex.update({
|
local conf = tablex.update({
|
||||||
|
mutex = mutex,
|
||||||
downsampling_filters = settings.downsampling_filters,
|
downsampling_filters = settings.downsampling_filters,
|
||||||
random_half_rate = settings.random_half_rate,
|
random_half_rate = settings.random_half_rate,
|
||||||
random_color_noise_rate = random_color_noise_rate,
|
random_color_noise_rate = random_color_noise_rate,
|
||||||
|
|
Loading…
Reference in a new issue