From 6ba6cfe1ffd483d1e98e710e5a773af3067b1aeb Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sat, 1 Apr 2017 03:55:09 +0900 Subject: [PATCH 1/2] Fix gpu option --- waifu2x.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/waifu2x.lua b/waifu2x.lua index 8e48937..05c164e 100644 --- a/waifu2x.lua +++ b/waifu2x.lua @@ -276,6 +276,7 @@ local function waifu2x() if opt.thread > 0 then torch.setnumthreads(opt.thread) end + cutorch.setDevice(opt.gpu) if cudnn then cudnn.fastest = true if opt.l:len() > 0 then @@ -293,6 +294,5 @@ local function waifu2x() else convert_frames(opt) end - cutorch.setDevice(opt.gpu) end waifu2x() From b7e116de542a1650cd67a965bdcaf052f200425b Mon Sep 17 00:00:00 2001 From: nagadomi Date: Mon, 10 Apr 2017 20:20:17 +0900 Subject: [PATCH 2/2] Add support for multi GPU training (data parallel) train.lua -gpu 1,3,4 When use multi GPU mode, nccl.torch is required. --- lib/settings.lua | 15 ++++++++++++--- lib/w2nn.lua | 41 +++++++++++++++++++++++++++++++++++++++++ train.lua | 14 ++++++-------- 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/lib/settings.lua b/lib/settings.lua index fb0b890..892be9f 100644 --- a/lib/settings.lua +++ b/lib/settings.lua @@ -18,7 +18,6 @@ local cmd = torch.CmdLine() cmd:text() cmd:text("waifu2x-training") cmd:text("Options:") -cmd:option("-gpu", -1, 'GPU Device ID') cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)') cmd:option("-data_dir", "./data", 'path to data directory') cmd:option("-backend", "cunn", '(cunn|cudnn)') @@ -74,7 +73,7 @@ cmd:option("-oracle_drop_rate", 0.5, '') cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))') cmd:option("-resume", "", 'resume model file') cmd:option("-name", "user", 'model name for user method') -cmd:option("-gpu", 1, 'Device ID') +cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)') cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)') cmd:option("-update_criterion", "mse", 'mse|loss') @@ -168,6 +167,16 @@ end settings.images = string.format("%s/images.t7", settings.data_dir) settings.image_list = string.format("%s/image_list.txt", settings.data_dir) -cutorch.setDevice(opt.gpu) +if settings.gpu:len() > 0 then + local gpus = {} + local gpu_string = utils.split(settings.gpu, ",") + for i = 1, #gpu_string do + table.insert(gpus, tonumber(gpu_string[i])) + end + settings.gpu = gpus +else + settings.gpu = {1} +end +cutorch.setDevice(settings.gpu[1]) return settings diff --git a/lib/w2nn.lua b/lib/w2nn.lua index a7f93e4..5a9c727 100644 --- a/lib/w2nn.lua +++ b/lib/w2nn.lua @@ -9,6 +9,40 @@ end local function load_cudnn() cudnn = require('cudnn') end +local function make_data_parallel_table(model, gpus) + if cudnn then + local fastest, benchmark = cudnn.fastest, cudnn.benchmark + local dpt = nn.DataParallelTable(1, true, true) + :add(model, gpus) + :threads(function() + require 'pl' + local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() + package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path + require 'torch' + require 'cunn' + require 'w2nn' + local cudnn = require 'cudnn' + cudnn.fastest, cudnn.benchmark = fastest, benchmark + end) + dpt.gradInput = nil + model = dpt:cuda() + else + local dpt = nn.DataParallelTable(1, true, true) + :add(model, gpus) + :threads(function() + require 'pl' + local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() + package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path + require 'torch' + require 'cunn' + require 'w2nn' + end) + dpt.gradInput = nil + model = dpt:cuda() + end + return model +end + if w2nn then return w2nn else @@ -27,6 +61,13 @@ else model:cuda():evaluate() return model end + function w2nn.data_parallel(model, gpus) + if #gpus > 1 then + return make_data_parallel_table(model, gpus) + else + return model + end + end require 'LeakyReLU' require 'ClippedWeightedHuberCriterion' require 'ClippedMSECriterion' diff --git a/train.lua b/train.lua index eb28eed..aa8137d 100644 --- a/train.lua +++ b/train.lua @@ -480,8 +480,9 @@ local function train() ch, settings.crop_size, settings.crop_size) end local instance_loss = nil + local pmodel = w2nn.data_parallel(model, settings.gpu) for epoch = 1, settings.epoch do - model:training() + pmodel:training() print("# " .. epoch) if adam_config.learningRate then print("learning rate: " .. adam_config.learningRate) @@ -519,13 +520,13 @@ local function train() instance_loss = torch.Tensor(x:size(1)):zero() for i = 1, settings.inner_epoch do - model:training() - local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config) + pmodel:training() + local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config) instance_loss:copy(il) print(train_score) - model:evaluate() + pmodel:evaluate() print("# validation") - local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize) + local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize) table.insert(hist_train, train_score.loss) table.insert(hist_valid, score.loss) if settings.plot then @@ -593,9 +594,6 @@ local function train() end end end -if settings.gpu > 0 then - cutorch.setDevice(settings.gpu) -end torch.manualSeed(settings.seed) cutorch.manualSeed(settings.seed) print(settings)