1
0
Fork 0
mirror of synced 2024-06-02 19:14:30 +12:00
waifu2x/lib/w2nn.lua

86 lines
2.3 KiB
Lua
Raw Normal View History

2015-10-28 19:30:47 +13:00
--- Load the CPU Torch stack: `torch` first, then `nn`.
-- NOTE(review): this helper appears unused within this file; confirm before removing.
local function load_nn()
   for _, name in ipairs({ 'torch', 'nn' }) do
      require(name)
   end
end
--- Load the CUDA Torch stack: `cutorch` first, then `cunn`.
-- Any `require` failure propagates as an error; the module body calls
-- this under `pcall` to report a friendlier message.
local function load_cunn()
   for _, name in ipairs({ 'cutorch', 'cunn' }) do
      require(name)
   end
end
--- Load cuDNN bindings and publish them as the global `cudnn`.
-- The global is deliberate: other code in this file branches on `if cudnn`.
local function load_cudnn()
   local lib = require 'cudnn'
   cudnn = lib
end
--- Wrap `model` in an nn.DataParallelTable spread across `gpus`.
-- The closure passed to `:threads()` runs in each worker thread. It
-- re-requires everything it needs, including this file via `require 'w2nn'`.
-- When the global `cudnn` is loaded, workers also load cudnn and copy the
-- main thread's `fastest` / `benchmark` flags.
-- @param model network to replicate
-- @param gpus list of GPU ids, passed to DataParallelTable:add
-- @return the DataParallelTable after `:cuda()`
local function make_data_parallel_table(model, gpus)
   -- Only the worker init closure differs between the two cases.
   -- Both closures are kept self-contained: everything they use is either an
   -- upvalue (fastest/benchmark) or re-required inside the worker.
   -- NOTE(review): presumably the closure is serialized into a fresh Lua state
   -- by the threads library; confirm before adding more upvalues.
   local init_thread
   if cudnn then
      local fastest, benchmark = cudnn.fastest, cudnn.benchmark
      init_thread = function()
         require 'pl'
         local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
         package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
         require 'torch'
         require 'cunn'
         require 'w2nn'
         local cudnn = require 'cudnn'
         cudnn.fastest, cudnn.benchmark = fastest, benchmark
      end
   else
      init_thread = function()
         require 'pl'
         local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
         package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
         require 'torch'
         require 'cunn'
         require 'w2nn'
      end
   end
   local dpt = nn.DataParallelTable(1, true, true)
      :add(model, gpus)
      :threads(init_thread)
   -- NOTE(review): dropping gradInput presumably saves memory because the
   -- wrapper's input gradient is never used; confirm.
   dpt.gradInput = nil
   return dpt:cuda()
end
2015-10-28 19:30:47 +13:00
-- Include guard: `w2nn` is a global, so a repeated load returns the table
-- that was already built.
if w2nn then
   return w2nn
else
   w2nn = {}
   -- CUDA (cutorch + cunn) is mandatory: abort with the underlying error.
   local state, ret = pcall(load_cunn)
   if not state then
      error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. ret)
   end
   -- cuDNN is optional. If `require 'cudnn'` fails, the global `cudnn` is
   -- not assigned, and code that checks `if cudnn` takes its non-cuDNN path.
   pcall(load_cudnn)

   --- Load a serialized model and prepare it for inference.
   -- @param model_path file path passed to torch.load
   -- @param force_cudnn if truthy, convert the model with cudnn.convert
   -- @param mode torch.load format, default "ascii"
   -- @return the loaded model, after `:cuda()` and `:evaluate()`
   -- @raise if force_cudnn is set but cudnn could not be loaded
   function w2nn.load_model(model_path, force_cudnn, mode)
      mode = mode or "ascii"
      local model = torch.load(model_path, mode)
      if force_cudnn then
         -- Without this check, a failed optional cudnn load surfaces as an
         -- opaque "attempt to index global 'cudnn'" error.
         if not cudnn then
            error("w2nn.load_model: force_cudnn was requested but the cudnn module could not be loaded", 2)
         end
         model = cudnn.convert(model, cudnn)
      end
      model:cuda():evaluate()
      return model
   end

   --- Wrap `model` for multi-GPU execution when more than one GPU is given.
   -- @param model network to wrap
   -- @param gpus list of GPU ids; nil or a single entry means "no wrapping"
   -- @return a DataParallelTable, or `model` unchanged
   function w2nn.data_parallel(model, gpus)
      if gpus and #gpus > 1 then
         return make_data_parallel_table(model, gpus)
      else
         return model
      end
   end

   -- Register the custom layers and criteria shipped alongside this file.
   require 'LeakyReLU'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'
   require 'SSIMCriterion'
   require 'InplaceClip01'
   require 'L1Criterion'
   require 'ShakeShakeTable'
   require 'PrintTable'
   require 'Print'
   require 'AuxiliaryLossTable'
   require 'AuxiliaryLossCriterion'
   return w2nn
end
2018-10-03 22:35:38 +13:00