-- w2nn module bootstrap: loads the Torch/CUDA backends, provides model
-- loading and multi-GPU helpers, and registers the custom layers/criteria.

-- CPU-only backend (not called in this file; the CUDA path below is used).
local function load_nn()
    require 'torch'
    require 'nn'
end
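-- CUDA backend; loaded via pcall below so a failure can be reported cleanly.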
local function load_cunn()
    require 'cutorch'
    require 'cunn'
end
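-- cuDNN is optional; `cudnn` is intentionally a global so that
-- make_data_parallel_table can detect whether it was loaded.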
local function load_cudnn()
    cudnn = require('cudnn')
end
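-- Wrap `model` in an nn.DataParallelTable over the GPU ids in `gpus`.
-- When cuDNN is available, its fastest/benchmark flags are propagated to
-- every worker thread.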
local function make_data_parallel_table(model, gpus)
    if cudnn then
        local fastest, benchmark = cudnn.fastest, cudnn.benchmark
        local dpt = nn.DataParallelTable(1, true, true)
            :add(model, gpus)
            :threads(function()
                -- each DPT worker runs in a fresh Lua state, so every
                -- dependency must be required again here
                require 'pl'
                local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
                package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
                require 'torch'
                require 'cunn'
                require 'w2nn'
                local cudnn = require 'cudnn'
                cudnn.fastest, cudnn.benchmark = fastest, benchmark
            end)
        dpt.gradInput = nil -- the outer container's gradInput is unused; drop it to save memory
        model = dpt:cuda()
    else
        local dpt = nn.DataParallelTable(1, true, true)
            :add(model, gpus)
            :threads(function()
                require 'pl'
                local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
                package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
                require 'torch'
                require 'cunn'
                require 'w2nn'
            end)
        dpt.gradInput = nil
        model = dpt:cuda()
    end
    return model
end
if w2nn then
    return w2nn
else
    w2nn = {}
    local state, ret = pcall(load_cunn)
    if not state then
        error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. ret)
    end
    pcall(load_cudnn)
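    -- Deserialize an ascii-format torch model; optionally convert its layers
    -- to their cuDNN implementations before moving it to GPU eval mode.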
    function w2nn.load_model(model_path, force_cudnn)
        local model = torch.load(model_path, "ascii")
        if force_cudnn then
            model = cudnn.convert(model, cudnn)
        end
        model:cuda():evaluate()
        return model
    end
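    -- Parallelize only when more than one GPU id is given; a single-GPU
    -- model is returned unchanged.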
    function w2nn.data_parallel(model, gpus)
        if #gpus > 1 then
            return make_data_parallel_table(model, gpus)
        else
            return model
        end
    end
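    -- Register the custom layers and criteria shipped with this library.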
    require 'LeakyReLU'
    require 'ClippedWeightedHuberCriterion'
    require 'ClippedMSECriterion'
    require 'SSIMCriterion'
    require 'InplaceClip01'
    require 'L1Criterion'
    require 'ShakeShakeTable'
    return w2nn
end
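-- Usage sketch (hypothetical path and GPU ids; assumes this file is
-- reachable on package.path as 'w2nn'):
--   local w2nn = require 'w2nn'
--   local model = w2nn.load_model("models/example_model.t7", false)
--   model = w2nn.data_parallel(model, {1, 2}) -- spread across GPUs 1 and 2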