Add to json formt model; Fix json format;
This commit is contained in:
parent
cba96e36fb
commit
0b1a13d9c0
|
@ -5,13 +5,19 @@ package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. packa
|
|||
require 'w2nn'
|
||||
local cjson = require "cjson"
|
||||
|
||||
local function meta_data(model)
|
||||
local function meta_data(model, model_path)
|
||||
local meta = {}
|
||||
for k, v in pairs(model) do
|
||||
if k:match("w2nn_") then
|
||||
meta[k:gsub("w2nn_", "")] = v
|
||||
end
|
||||
end
|
||||
|
||||
modtime = file.modified_time(model_path)
|
||||
utc_date = Date('utc')
|
||||
utc_date:set(modtime)
|
||||
meta["created_at"] = tostring(utc_date)
|
||||
|
||||
return meta
|
||||
end
|
||||
local function includes(s, a)
|
||||
|
@ -33,7 +39,9 @@ end
|
|||
local function export_weight(jmodules, seq)
|
||||
local targets = {"nn.SpatialConvolutionMM",
|
||||
"cudnn.SpatialConvolution",
|
||||
"cudnn.SpatialDilatedConvolution",
|
||||
"nn.SpatialFullConvolution",
|
||||
"nn.SpatialDilatedConvolution",
|
||||
"cudnn.SpatialFullConvolution"
|
||||
}
|
||||
for k = 1, #seq.modules do
|
||||
|
@ -56,25 +64,27 @@ local function export_weight(jmodules, seq)
|
|||
dW = mod.dW,
|
||||
padW = mod.padW,
|
||||
padH = mod.padH,
|
||||
dilationW = mod.dilationW,
|
||||
dilationH = mod.dilationH,
|
||||
nInputPlane = mod.nInputPlane,
|
||||
nOutputPlane = mod.nOutputPlane,
|
||||
bias = torch.totable(get_bias(mod)),
|
||||
weight = weight
|
||||
}
|
||||
if first_layer then
|
||||
first_layer = false
|
||||
jmod.model_config = model_config
|
||||
end
|
||||
table.insert(jmodules, jmod)
|
||||
end
|
||||
end
|
||||
end
|
||||
local function export(model, output)
|
||||
local function export(model, model_path, output)
|
||||
local jmodules = {}
|
||||
local model_config = meta_data(model)
|
||||
local model_config = meta_data(model, model_path)
|
||||
local first_layer = true
|
||||
|
||||
print(model_config)
|
||||
print(model)
|
||||
|
||||
export_weight(jmodules, model)
|
||||
jmodules[1]["model_config"] = model_config
|
||||
|
||||
local fp = io.open(output, "w")
|
||||
if not fp then
|
||||
|
@ -98,4 +108,4 @@ if not path.isfile(opt.i) then
|
|||
os.exit(-1)
|
||||
end
|
||||
local model = torch.load(opt.i, opt.iformat)
|
||||
export(model, opt.o)
|
||||
export(model, opt.i, opt.o)
|
||||
|
|
Loading…
Reference in a new issue