Add support for multi-GPU training (data parallel).
Usage: train.lua -gpu 1,3,4. When using multi-GPU mode, nccl.torch is required.
This commit is contained in:
parent
6ba6cfe1ff
commit
b7e116de54
|
@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
|
|||
cmd:text()
|
||||
cmd:text("waifu2x-training")
|
||||
cmd:text("Options:")
|
||||
cmd:option("-gpu", -1, 'GPU Device ID')
|
||||
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
|
||||
cmd:option("-data_dir", "./data", 'path to data directory')
|
||||
cmd:option("-backend", "cunn", '(cunn|cudnn)')
|
||||
|
@ -74,7 +73,7 @@ cmd:option("-oracle_drop_rate", 0.5, '')
|
|||
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
|
||||
cmd:option("-resume", "", 'resume model file')
|
||||
cmd:option("-name", "user", 'model name for user method')
|
||||
-- GPU selection: a single device ID ("1") or a comma separated list ("1,3,4").
-- Parsed into a table of numeric IDs later in this file; empty means {1}.
-- (The old numeric `-gpu` option with default 1 was superseded by this
-- string-typed option and must not be registered twice.)
cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma separated)')
|
||||
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
|
||||
cmd:option("-update_criterion", "mse", 'mse|loss')
|
||||
|
||||
|
@ -168,6 +167,16 @@ end
|
|||
settings.images = string.format("%s/images.t7", settings.data_dir)
|
||||
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
|
||||
|
||||
-- Parse -gpu into a table of numeric device IDs.
-- Accepts a single ID ("1") or a comma separated list ("1,3,4");
-- an empty string defaults to {1}. The stale pre-parse
-- `cutorch.setDevice(opt.gpu)` call was removed: `opt` is undefined here
-- and the device cannot be chosen before the ID list is parsed.
if settings.gpu:len() > 0 then
   local gpus = {}
   for _, id in ipairs(utils.split(settings.gpu, ",")) do
      -- Fail loudly on non-numeric IDs instead of inserting nil,
      -- which would silently truncate the list (#gpus) at the hole.
      table.insert(gpus, assert(tonumber(id), "invalid GPU ID: " .. id))
   end
   settings.gpu = gpus
else
   settings.gpu = {1}
end
-- The first listed GPU is the primary/default device.
cutorch.setDevice(settings.gpu[1])
|
||||
|
||||
return settings
|
||||
|
|
41
lib/w2nn.lua
41
lib/w2nn.lua
|
@ -9,6 +9,40 @@ end
|
|||
local function load_cudnn()
|
||||
cudnn = require('cudnn')
|
||||
end
|
||||
-- Wrap `model` in an nn.DataParallelTable replicated across the GPU IDs in
-- `gpus` (batch split on dimension 1, flattened params, NCCL enabled).
-- When cudnn is loaded, its fastest/benchmark flags are forwarded to the
-- per-GPU worker threads so every replica uses the same convolution config.
-- Returns the CUDA-converted DataParallelTable.
-- NOTE: the original duplicated the whole construction in two near-identical
-- branches (with/without cudnn); this version factors the common part out.
local function make_data_parallel_table(model, gpus)
   -- Capture cudnn state *before* spawning threads; upvalues are the only
   -- state shared with the thread-init closure below.
   local use_cudnn = cudnn ~= nil
   local fastest, benchmark
   if use_cudnn then
      fastest, benchmark = cudnn.fastest, cudnn.benchmark
   end
   local dpt = nn.DataParallelTable(1, true, true)
      :add(model, gpus)
      :threads(function()
         require 'pl'
         -- Resolve the lib/ directory relative to this file so worker
         -- threads can require project modules regardless of the cwd.
         local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
         package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
         require 'torch'
         require 'cunn'
         require 'w2nn'
         if use_cudnn then
            local cudnn = require 'cudnn'
            cudnn.fastest, cudnn.benchmark = fastest, benchmark
         end
      end)
   -- The wrapper never needs a gradient w.r.t. its own input here.
   dpt.gradInput = nil
   return dpt:cuda()
end
|
||||
|
||||
if w2nn then
|
||||
return w2nn
|
||||
else
|
||||
|
@ -27,6 +61,13 @@ else
|
|||
model:cuda():evaluate()
|
||||
return model
|
||||
end
|
||||
-- Return a multi-GPU wrapper for `model` when more than one device ID is
-- given; with a single (or empty) device list the model is returned as-is.
function w2nn.data_parallel(model, gpus)
   if #gpus <= 1 then
      return model
   end
   return make_data_parallel_table(model, gpus)
end
|
||||
require 'LeakyReLU'
|
||||
require 'ClippedWeightedHuberCriterion'
|
||||
require 'ClippedMSECriterion'
|
||||
|
|
14
train.lua
14
train.lua
|
@ -480,8 +480,9 @@ local function train()
|
|||
ch, settings.crop_size, settings.crop_size)
|
||||
end
|
||||
local instance_loss = nil
|
||||
local pmodel = w2nn.data_parallel(model, settings.gpu)
|
||||
for epoch = 1, settings.epoch do
|
||||
model:training()
|
||||
pmodel:training()
|
||||
print("# " .. epoch)
|
||||
if adam_config.learningRate then
|
||||
print("learning rate: " .. adam_config.learningRate)
|
||||
|
@ -519,13 +520,13 @@ local function train()
|
|||
instance_loss = torch.Tensor(x:size(1)):zero()
|
||||
|
||||
for i = 1, settings.inner_epoch do
|
||||
model:training()
|
||||
local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
|
||||
pmodel:training()
|
||||
local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config)
|
||||
instance_loss:copy(il)
|
||||
print(train_score)
|
||||
model:evaluate()
|
||||
pmodel:evaluate()
|
||||
print("# validation")
|
||||
local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
|
||||
local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
|
||||
table.insert(hist_train, train_score.loss)
|
||||
table.insert(hist_valid, score.loss)
|
||||
if settings.plot then
|
||||
|
@ -593,9 +594,6 @@ local function train()
|
|||
end
|
||||
end
|
||||
end
|
||||
-- Select the primary training device. settings.gpu is a table of device IDs
-- (parsed in lib/settings.lua); the old `settings.gpu > 0` numeric check
-- would raise "attempt to compare table with number" after the multi-GPU
-- change, so compare the list length and index the first entry instead.
if #settings.gpu > 0 then
   cutorch.setDevice(settings.gpu[1])
end
|
||||
torch.manualSeed(settings.seed)
|
||||
cutorch.manualSeed(settings.seed)
|
||||
print(settings)
|
||||
|
|
Loading…
Reference in a new issue