no confidence change for #128
This commit is contained in:
parent
8739f5487f
commit
382d493514
|
@ -7,18 +7,17 @@ local function load_cunn()
|
||||||
require 'cunn'
|
require 'cunn'
|
||||||
end
|
end
|
||||||
local function load_cudnn()
|
local function load_cudnn()
|
||||||
require 'cudnn'
|
cudnn = require('cudnn')
|
||||||
cudnn.benchmark = true
|
|
||||||
end
|
end
|
||||||
if w2nn then
|
if w2nn then
|
||||||
return w2nn
|
return w2nn
|
||||||
else
|
else
|
||||||
|
w2nn = {}
|
||||||
local state, ret = pcall(load_cunn)
|
local state, ret = pcall(load_cunn)
|
||||||
if not state then
|
if not state then
|
||||||
error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. ret)
|
error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. ret)
|
||||||
end
|
end
|
||||||
pcall(load_cudnn)
|
pcall(load_cudnn)
|
||||||
w2nn = {}
|
|
||||||
|
|
||||||
function w2nn.load_model(model_path, force_cudnn)
|
function w2nn.load_model(model_path, force_cudnn)
|
||||||
local model = torch.load(model_path, "ascii")
|
local model = torch.load(model_path, "ascii")
|
||||||
|
|
Loading…
Reference in a new issue