1
0
Fork 0
mirror of synced 2024-06-29 11:41:29 +12:00
waifu2x/lib/w2nn.lua

37 lines
826 B
Lua
Raw Normal View History

2015-10-28 19:30:47 +13:00
--- Load Torch core and the `nn` package, in that order.
-- Nothing in this file calls it; kept for callers that only need CPU modules.
local function load_nn()
   for _, module_name in ipairs({ 'torch', 'nn' }) do
      require(module_name)
   end
end
--- Load the CUDA backends: `cutorch` first, then `cunn`.
-- Errors from `require` propagate; the module body wraps this in pcall.
local function load_cunn()
   local cuda_modules = { 'cutorch', 'cunn' }
   for idx = 1, #cuda_modules do
      require(cuda_modules[idx])
   end
end
--- Load the `cudnn` module and publish it as the global `cudnn`.
-- The global is deliberate: w2nn.load_model reads `cudnn` by global name.
local function load_cudnn()
   local cudnn_module = require('cudnn')
   cudnn = cudnn_module
end
-- Module body. `w2nn` is a global singleton: the first execution builds it,
-- and any later execution of this chunk returns the existing table.
if w2nn then
   return w2nn
else
   w2nn = {}

   -- cutorch/cunn are mandatory: abort with a readable message if they fail.
   local cunn_ok, cunn_err = pcall(load_cunn)
   if not cunn_ok then
      -- tostring: pcall may hand back a non-string error object, which `..` rejects.
      error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. tostring(cunn_err))
   end

   -- cuDNN is optional. Remember whether it loaded so load_model can report a
   -- clear error instead of indexing a nil `cudnn` global later on.
   local cudnn_ok, cudnn_err = pcall(load_cudnn)

   --- Load a serialized model, optionally convert it with cuDNN, and prepare it for inference.
   -- @tparam string model_path path passed to torch.load with the "ascii" format
   -- @tparam[opt] boolean force_cudnn if truthy, convert the model via cudnn.convert
   -- @return the loaded model after `:cuda()` and `:evaluate()` have been called on it
   -- @raise when force_cudnn is set but the cudnn module failed to load
   function w2nn.load_model(model_path, force_cudnn)
      -- Fail fast, before the (possibly expensive) model load.
      if force_cudnn and not cudnn_ok then
         error("cuDNN was requested but could not be loaded.\n---\n" .. tostring(cudnn_err), 2)
      end
      local model = torch.load(model_path, "ascii")
      if force_cudnn then
         model = cudnn.convert(model, cudnn)
      end
      model:cuda():evaluate()
      return model
   end

   -- Project-local layer/criterion definitions; what they register is not visible here.
   require 'LeakyReLU'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'
   require 'SSIMCriterion'
   require 'InplaceClip01'
   return w2nn
end