
Merge branch 'dev'

nagadomi committed 2017-04-10 20:22:37 +09:00
commit f54dd37848
4 changed files with 60 additions and 12 deletions

View file

@@ -18,7 +18,6 @@ local cmd = torch.CmdLine()
cmd:text()
cmd:text("waifu2x-training")
cmd:text("Options:")
cmd:option("-gpu", -1, 'GPU Device ID')
cmd:option("-seed", 11, 'RNG seed (note: it only able to reproduce the training results with `-thread 1`)')
cmd:option("-data_dir", "./data", 'path to data directory')
cmd:option("-backend", "cunn", '(cunn|cudnn)')
@@ -74,7 +73,7 @@ cmd:option("-oracle_drop_rate", 0.5, '')
cmd:option("-learning_rate_decay", 3.0e-7, 'learning rate decay (learning_rate * 1/(1+num_of_data*patches*epoch))')
cmd:option("-resume", "", 'resume model file')
cmd:option("-name", "user", 'model name for user method')
cmd:option("-gpu", 1, 'Device ID')
cmd:option("-gpu", "", 'GPU Device ID or ID lists (comma seprated)')
cmd:option("-loss", "huber", 'loss function (huber|l1|mse|bce)')
cmd:option("-update_criterion", "mse", 'mse|loss')
cmd:option("-padding", 0, 'replication padding size')
@@ -173,6 +172,16 @@ end
settings.images = string.format("%s/images.t7", settings.data_dir)
settings.image_list = string.format("%s/image_list.txt", settings.data_dir)
cutorch.setDevice(opt.gpu)
if settings.gpu:len() > 0 then
   local gpus = {}
   local gpu_string = utils.split(settings.gpu, ",")
   for i = 1, #gpu_string do
      table.insert(gpus, tonumber(gpu_string[i]))
   end
   settings.gpu = gpus
else
   settings.gpu = {1}
end
cutorch.setDevice(settings.gpu[1])
return settings
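
The new -gpu handling accepts either a single device ID or a comma-separated list, falls back to {1} when the option is empty, and makes the first listed device current. A minimal standalone sketch of the same parsing idea, using plain string.gmatch instead of the repository's utils.split helper (the function name parse_gpu_option is only for illustration):

-- Sketch only: turn a "-gpu" value such as "1,2,3" into the table {1, 2, 3},
-- falling back to {1} when the option string is empty.
local function parse_gpu_option(gpu_arg)
   local gpus = {}
   for id in string.gmatch(gpu_arg or "", "[^,]+") do
      table.insert(gpus, tonumber(id))
   end
   if #gpus == 0 then
      gpus = {1}
   end
   return gpus
end

-- Example: parse_gpu_option("2,3") returns {2, 3}; the trainer then calls
-- cutorch.setDevice(gpus[1]) so the first listed GPU becomes the default device.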

View file

@@ -9,6 +9,40 @@ end
local function load_cudnn()
   cudnn = require('cudnn')
end
local function make_data_parallel_table(model, gpus)
   if cudnn then
      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
      local dpt = nn.DataParallelTable(1, true, true)
         :add(model, gpus)
         :threads(function()
            require 'pl'
            local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
            package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
            require 'torch'
            require 'cunn'
            require 'w2nn'
            local cudnn = require 'cudnn'
            cudnn.fastest, cudnn.benchmark = fastest, benchmark
         end)
      dpt.gradInput = nil
      model = dpt:cuda()
   else
      local dpt = nn.DataParallelTable(1, true, true)
         :add(model, gpus)
         :threads(function()
            require 'pl'
            local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
            package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
            require 'torch'
            require 'cunn'
            require 'w2nn'
         end)
      dpt.gradInput = nil
      model = dpt:cuda()
   end
   return model
end
if w2nn then
   return w2nn
else
@@ -27,6 +61,13 @@ else
      model:cuda():evaluate()
      return model
   end
   function w2nn.data_parallel(model, gpus)
      if #gpus > 1 then
         return make_data_parallel_table(model, gpus)
      else
         return model
      end
   end
   require 'LeakyReLU'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'
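
nn.DataParallelTable (from cunn) replicates a module across the listed GPUs and splits each minibatch along the given dimension, which is what make_data_parallel_table builds above. A minimal usage sketch, independent of the waifu2x model code and assuming two CUDA devices are visible; the toy network and batch size are placeholders:

-- Sketch only: wrap a small network with DataParallelTable over GPUs 1 and 2.
require 'cunn'

local net = nn.Sequential()
   :add(nn.SpatialConvolution(3, 16, 3, 3, 1, 1, 1, 1))
   :add(nn.ReLU(true))

-- args: split along dimension 1, flatten parameters, try NCCL if available
local dpt = nn.DataParallelTable(1, true, true)
dpt:add(net, {1, 2})
dpt = dpt:cuda()

-- Each GPU receives half of the batch along the first dimension.
local input = torch.CudaTensor(8, 3, 32, 32):normal()
local output = dpt:forward(input)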

View file

@@ -480,8 +480,9 @@ local function train()
                       ch, settings.crop_size, settings.crop_size)
   end
   local instance_loss = nil
   local pmodel = w2nn.data_parallel(model, settings.gpu)
   for epoch = 1, settings.epoch do
      model:training()
      pmodel:training()
      print("# " .. epoch)
      if adam_config.learningRate then
         print("learning rate: " .. adam_config.learningRate)
@@ -519,13 +520,13 @@ local function train()
      instance_loss = torch.Tensor(x:size(1)):zero()
      for i = 1, settings.inner_epoch do
         model:training()
         local train_score, il = minibatch_adam(model, criterion, eval_metric, x, y, adam_config)
         pmodel:training()
         local train_score, il = minibatch_adam(pmodel, criterion, eval_metric, x, y, adam_config)
         instance_loss:copy(il)
         print(train_score)
         model:evaluate()
         pmodel:evaluate()
         print("# validation")
         local score = validate(model, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
         local score = validate(pmodel, criterion, eval_metric, valid_xy, adam_config.xBatchSize)
         table.insert(hist_train, train_score.loss)
         table.insert(hist_valid, score.loss)
         if settings.plot then
@@ -593,9 +594,6 @@ local function train()
      end
   end
end
if settings.gpu > 0 then
   cutorch.setDevice(settings.gpu)
end
torch.manualSeed(settings.seed)
cutorch.manualSeed(settings.seed)
print(settings)
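
The training loop now drives pmodel, the wrapper returned by w2nn.data_parallel (which is the model itself when only one GPU is configured), while the original model object is still the one that gets serialized. A hedged sketch of that split; save_path is a placeholder name, not a variable from train.lua:

-- Sketch only: train and validate through the parallel wrapper,
-- but checkpoint the plain module so the saved file does not depend
-- on the multi-GPU wrapper or the device layout.
local pmodel = w2nn.data_parallel(model, settings.gpu)
pmodel:training()
-- ... forward/backward passes through pmodel go here ...
pmodel:evaluate()

model:clearState()            -- drop intermediate buffers before saving
torch.save(save_path, model)  -- save_path is hypothetical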

View file

@@ -276,6 +276,7 @@ local function waifu2x()
   if opt.thread > 0 then
      torch.setnumthreads(opt.thread)
   end
   cutorch.setDevice(opt.gpu)
   if cudnn then
      cudnn.fastest = true
      if opt.l:len() > 0 then
@@ -293,6 +294,5 @@ local function waifu2x()
   else
      convert_frames(opt)
   end
   cutorch.setDevice(opt.gpu)
end
waifu2x()
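
Moving cutorch.setDevice(opt.gpu) ahead of the conversion calls matters because CUDA tensors and loaded models are placed on whichever device is current at allocation time; setting the device only at the end, as the old code did, left everything on the default GPU. A small illustration of that behaviour, assuming at least two GPUs are visible:

-- Sketch only: tensors land on the device that is current when they are created.
require 'cutorch'

cutorch.setDevice(2)
local t = torch.CudaTensor(4, 4)
print(cutorch.getDevice())  -- prints 2
print(t:getDevice())        -- prints 2: the tensor lives on the device selected beforehand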