1
0
Fork 0
mirror of synced 2024-06-01 18:49:33 +12:00
waifu2x/tools/export_model.lua
2015-10-28 16:01:07 +09:00

25 lines
852 B
Lua

-- adapted from https://github.com/marcan/cl-waifu2x
-- `path` (used below) is Penlight's pl.path; `require 'pl'` exposes it as a
-- global. It must be loaded before its first use on the package.path line.
require 'pl'
-- Path of this script: level 2 of debug.getinfo is the caller of the
-- anonymous function (the main chunk); its source is "@<file>", so the
-- leading "@" is stripped. Only gsub's first return value is kept.
local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)()
-- Let `require` find the project's modules in ../lib relative to this script.
package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. package.path
require 'w2nn'
local cjson = require "cjson"
--- Export every nn.SpatialConvolutionMM layer of a serialized Torch model as a
-- JSON array written to stdout (one object per layer, in findModules order).
--
-- Usage: th export_model.lua <model.t7> [format]
--   <model.t7>  model file to load with torch.load
--   [format]    torch.load serialization format (default: "ascii")
--
-- Each array element has the fields kW, kH, nInputPlane, nOutputPlane,
-- bias (flat list) and weight (nested nOutputPlane x nInputPlane x kH x kW).
if arg[1] == nil then
   io.stderr:write("usage: th export_model.lua <model.t7> [format]\n")
   os.exit(1)
end
local model_file = arg[1]
local model_format = arg[2] or "ascii"
local model = torch.load(model_file, model_format)

local convs = model:findModules("nn.SpatialConvolutionMM")
-- cjson encodes an empty Lua table as the object "{}" rather than the array
-- "[]", so an empty result would silently produce output of the wrong shape.
if #convs == 0 then
   error("no nn.SpatialConvolutionMM layers found in " .. model_file)
end

local jmodules = {}
for i, conv in ipairs(convs) do
   -- SpatialConvolutionMM unfolds each kernel as (input plane, row, column),
   -- so the conventional 4-D view is nOutputPlane x nInputPlane x kH x kW.
   -- For square kernels this is identical to the previous kW x kH reshape.
   local weight = conv.weight:float():reshape(conv.nOutputPlane, conv.nInputPlane, conv.kH, conv.kW)
   jmodules[i] = {
      kW = conv.kW,
      kH = conv.kH,
      nInputPlane = conv.nInputPlane,
      nOutputPlane = conv.nOutputPlane,
      bias = torch.totable(conv.bias:float()),
      weight = torch.totable(weight),
   }
end
io.write(cjson.encode(jmodules))