1
0
Fork 0
mirror of synced 2024-06-26 10:10:49 +12:00
This commit is contained in:
nagadomi 2016-09-24 05:51:47 +09:00
parent 5a3d012f4e
commit a14e6acec3

View file

@ -43,11 +43,15 @@ local function split_data(x, test_size)
end
local g_transform_pool = nil
local g_mutex = nil
local g_mutex_id = nil
local function transform_pool_init(has_resize, offset)
local nthread = torch.getnumthreads()
if (settings.thread > 0) then
nthread = settings.thread
end
g_mutex = threads.Mutex()
g_mutex_id = g_mutex:id()
g_transform_pool = threads.Threads(
nthread,
function(threadid)
@ -56,10 +60,13 @@ local function transform_pool_init(has_resize, offset)
package.path = path.join(path.dirname(__FILE__), "lib", "?.lua;") .. package.path
require 'nn'
require 'cunn'
local threads = require 'threads'
local compression = require 'compression'
local pairwise_transform = require 'pairwise_transform'
function transformer(x, is_validation, n)
local mutex = threads.Mutex(g_mutex_id)
local meta = {data = {}}
local y = nil
if type(x) == "table" and type(x[2]) == "table" then
@ -92,6 +99,7 @@ local function transform_pool_init(has_resize, offset)
end
if settings.method == "scale" then
local conf = tablex.update({
mutex = mutex,
downsampling_filters = settings.downsampling_filters,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
@ -114,6 +122,7 @@ local function transform_pool_init(has_resize, offset)
n, conf)
elseif settings.method == "noise" then
local conf = tablex.update({
mutex = mutex,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,
random_overlay_rate = random_overlay_rate,
@ -135,6 +144,7 @@ local function transform_pool_init(has_resize, offset)
n, conf)
elseif settings.method == "noise_scale" then
local conf = tablex.update({
mutex = mutex,
downsampling_filters = settings.downsampling_filters,
random_half_rate = settings.random_half_rate,
random_color_noise_rate = random_color_noise_rate,