2015-10-28 19:30:47 +13:00
|
|
|
--- Load the CPU-side Torch stack: torch first, then nn.
-- Any require failure propagates to the caller.
-- NOTE(review): not called anywhere in the visible code.
local function load_nn()
   for _, name in ipairs({ 'torch', 'nn' }) do
      require(name)
   end
end
|
|
|
|
--- Load the CUDA Torch stack: cutorch first, then cunn.
-- Any require failure propagates; the module body calls this through pcall.
local function load_cunn()
   local modules = { 'cutorch', 'cunn' }
   for i = 1, #modules do
      require(modules[i])
   end
end
|
|
|
|
--- Load cuDNN bindings and publish them as the global `cudnn`.
-- The global is deliberate: w2nn.load_model refers to `cudnn` by name.
-- If require fails, the error propagates and the global is left untouched.
local function load_cudnn()
   local mod = require('cudnn')
   cudnn = mod
end
|
|
|
|
-- Module body. The module table lives in the global `w2nn`; if it already
-- exists, the existing table is returned instead of re-running setup.
if w2nn then
   return w2nn
else
   w2nn = {}

   -- CUDA (cutorch + cunn) is mandatory: abort with the underlying reason.
   -- tostring() guards against non-string error values, which `..` would
   -- otherwise reject with a confusing "attempt to concatenate" error.
   local state, ret = pcall(load_cunn)
   if not state then
      error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. tostring(ret))
   end

   -- cuDNN is optional and loaded best-effort. Remember whether it worked
   -- (and why not) so load_model can report a clear error if it is required.
   local cudnn_ok, cudnn_err = pcall(load_cudnn)

   --- Load a serialized model, move it to CUDA and switch it to evaluation mode.
   -- @tparam string model_path path to a model file readable by torch.load in "ascii" format
   -- @param[opt] force_cudnn when truthy, convert the model with cudnn.convert
   -- @return the loaded model after :cuda():evaluate()
   -- @raise error if force_cudnn is set but the cudnn module failed to load
   function w2nn.load_model(model_path, force_cudnn)
      local model = torch.load(model_path, "ascii")
      if force_cudnn then
         if not cudnn_ok then
            error("force_cudnn was requested, but the cudnn module could not be loaded:\n"
                     .. tostring(cudnn_err), 2)
         end
         model = cudnn.convert(model, cudnn)
      end
      model:cuda():evaluate()
      return model
   end

   -- Project-local layers/criteria; presumably each registers itself on
   -- load (e.g. into nn) — TODO confirm against those files.
   require 'LeakyReLU'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'

   return w2nn
end
|