1
0
Fork 0
mirror of synced 2024-06-02 02:54:31 +12:00
waifu2x/lib/w2nn.lua

91 lines
2.4 KiB
Lua

--- Require the 'torch' and 'nn' modules, in that order.
-- NOTE(review): not called anywhere in the visible part of this file;
-- load_cunn is what the module body actually uses.
local function load_nn()
   for _, module_name in ipairs({ 'torch', 'nn' }) do
      require(module_name)
   end
end
--- Require the 'cutorch' and 'cunn' modules, in that order.
-- Invoked through pcall by the module body, which turns a failure here into
-- a "Failed to load CUDA modules" error.
local function load_cunn()
   local cuda_modules = { 'cutorch', 'cunn' }
   for idx = 1, #cuda_modules do
      require(cuda_modules[idx])
   end
end
--- Require 'cudnn' and publish it as the global `cudnn`.
-- The global (not a local) is intentional: make_data_parallel_table and
-- w2nn.load_model read `cudnn` as a global. The module body calls this
-- through pcall, so a missing cudnn is not fatal.
local function load_cudnn()
   local cudnn_module = require('cudnn')
   cudnn = cudnn_module
end
--- Build the initializer that DataParallelTable runs in each worker thread.
-- The initializer extends package.path with this file's lib directory and
-- re-requires torch, cunn and w2nn (and cudnn when `use_cudnn` is set),
-- presumably because each worker has its own Lua state -- confirm.
-- Only plain values are captured as upvalues.
-- NOTE(review): assumes the threads library serializes this closure together
-- with its upvalues, so keep them to booleans/numbers/nil.
-- @tparam boolean use_cudnn whether workers should load cudnn
-- @param fastest value copied into cudnn.fastest in each worker
-- @param benchmark value copied into cudnn.benchmark in each worker
-- @treturn function initializer for DataParallelTable:threads
local function make_thread_init(use_cudnn, fastest, benchmark)
   return function()
      require 'pl'
      -- Level 2 of the stack is this initializer, whose source is this file.
      local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
      package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
      require 'torch'
      require 'cunn'
      require 'w2nn'
      if use_cudnn then
         local cudnn = require 'cudnn'
         cudnn.fastest, cudnn.benchmark = fastest, benchmark
      end
   end
end

--- Wrap `model` in an nn.DataParallelTable over `gpus` and convert it with :cuda().
-- When the global `cudnn` is loaded in the calling state, its current
-- `fastest` and `benchmark` flags are propagated to every worker thread.
-- @param model network to replicate
-- @param gpus list of GPU ids passed to DataParallelTable:add
-- @return the DataParallelTable wrapper after :cuda()
local function make_data_parallel_table(model, gpus)
   local use_cudnn = not not cudnn
   local fastest, benchmark
   if use_cudnn then
      fastest, benchmark = cudnn.fastest, cudnn.benchmark
   end
   local dpt = nn.DataParallelTable(1, true, true)
      :add(model, gpus)
      :threads(make_thread_init(use_cudnn, fastest, benchmark))
   -- Drop the wrapper's gradInput buffer.
   -- NOTE(review): presumably because callers never backpropagate into the
   -- input -- confirm.
   dpt.gradInput = nil
   return dpt:cuda()
end
-- Module body. The module table is published as the global `w2nn`. If that
-- global already exists, it is returned unchanged and nothing is re-run.
if w2nn then
   return w2nn
else
   w2nn = {}
   -- cutorch/cunn are mandatory; surface the underlying loader error.
   local state, ret = pcall(load_cunn)
   if not state then
      error("Failed to load CUDA modules. Please check the CUDA Settings.\n---\n" .. tostring(ret))
   end
   -- cudnn is optional. On failure the global `cudnn` is left as it was, and
   -- the error is kept so force_cudnn can report why cudnn is unavailable.
   -- On success pcall returns only `true`, so cudnn_err is nil.
   local _, cudnn_err = pcall(load_cudnn)

   --- Load a serialized model, optionally convert it to cudnn, move it to the
   -- GPU and switch it to evaluation mode.
   -- @tparam string model_path file passed to torch.load
   -- @tparam[opt] boolean force_cudnn convert the model with cudnn.convert
   -- @tparam[opt="ascii"] string mode serialization format for torch.load
   -- @return the loaded model
   -- @raise if force_cudnn is set but cudnn is not loaded
   function w2nn.load_model(model_path, force_cudnn, mode)
      mode = mode or "ascii"
      -- Check before torch.load so a missing cudnn fails fast with a clear
      -- message instead of "attempt to index global 'cudnn'".
      if force_cudnn and not cudnn then
         error("force_cudnn was requested but cudnn is not available"
               .. (cudnn_err and ("\n---\n" .. tostring(cudnn_err)) or ""))
      end
      local model = torch.load(model_path, mode)
      if force_cudnn then
         model = cudnn.convert(model, cudnn)
      end
      model:cuda():evaluate()
      return model
   end

   --- Wrap `model` for multi-GPU execution when more than one GPU is given.
   -- @param model network to (possibly) wrap
   -- @tparam[opt] table gpus list of GPU ids; nil or a single id means no wrapping
   -- @return a DataParallelTable wrapper, or `model` itself
   function w2nn.data_parallel(model, gpus)
      if gpus and #gpus > 1 then
         return make_data_parallel_table(model, gpus)
      end
      return model
   end

   -- Project-local layer and criterion definitions, resolved via package.path.
   require 'LeakyReLU'
   require 'ClippedWeightedHuberCriterion'
   require 'ClippedMSECriterion'
   require 'SSIMCriterion'
   require 'InplaceClip01'
   require 'L1Criterion'
   require 'ShakeShakeTable'
   require 'PrintTable'
   require 'Print'
   require 'AuxiliaryLossTable'
   require 'AuxiliaryLossCriterion'
   require 'GradWeight'
   require 'RandomBinaryConvolution'
   require 'LBPCriterion'
   require 'EdgeFilter'
   require 'ScaleTable'
   return w2nn
end