Add to json formt model; Fix json format;
This commit is contained in:
parent
cba96e36fb
commit
0b1a13d9c0
|
@ -5,13 +5,19 @@ package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. packa
|
||||||
require 'w2nn'
|
require 'w2nn'
|
||||||
local cjson = require "cjson"
|
local cjson = require "cjson"
|
||||||
|
|
||||||
local function meta_data(model)
|
local function meta_data(model, model_path)
|
||||||
local meta = {}
|
local meta = {}
|
||||||
for k, v in pairs(model) do
|
for k, v in pairs(model) do
|
||||||
if k:match("w2nn_") then
|
if k:match("w2nn_") then
|
||||||
meta[k:gsub("w2nn_", "")] = v
|
meta[k:gsub("w2nn_", "")] = v
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
modtime = file.modified_time(model_path)
|
||||||
|
utc_date = Date('utc')
|
||||||
|
utc_date:set(modtime)
|
||||||
|
meta["created_at"] = tostring(utc_date)
|
||||||
|
|
||||||
return meta
|
return meta
|
||||||
end
|
end
|
||||||
local function includes(s, a)
|
local function includes(s, a)
|
||||||
|
@ -33,7 +39,9 @@ end
|
||||||
local function export_weight(jmodules, seq)
|
local function export_weight(jmodules, seq)
|
||||||
local targets = {"nn.SpatialConvolutionMM",
|
local targets = {"nn.SpatialConvolutionMM",
|
||||||
"cudnn.SpatialConvolution",
|
"cudnn.SpatialConvolution",
|
||||||
|
"cudnn.SpatialDilatedConvolution",
|
||||||
"nn.SpatialFullConvolution",
|
"nn.SpatialFullConvolution",
|
||||||
|
"nn.SpatialDilatedConvolution",
|
||||||
"cudnn.SpatialFullConvolution"
|
"cudnn.SpatialFullConvolution"
|
||||||
}
|
}
|
||||||
for k = 1, #seq.modules do
|
for k = 1, #seq.modules do
|
||||||
|
@ -56,25 +64,27 @@ local function export_weight(jmodules, seq)
|
||||||
dW = mod.dW,
|
dW = mod.dW,
|
||||||
padW = mod.padW,
|
padW = mod.padW,
|
||||||
padH = mod.padH,
|
padH = mod.padH,
|
||||||
|
dilationW = mod.dilationW,
|
||||||
|
dilationH = mod.dilationH,
|
||||||
nInputPlane = mod.nInputPlane,
|
nInputPlane = mod.nInputPlane,
|
||||||
nOutputPlane = mod.nOutputPlane,
|
nOutputPlane = mod.nOutputPlane,
|
||||||
bias = torch.totable(get_bias(mod)),
|
bias = torch.totable(get_bias(mod)),
|
||||||
weight = weight
|
weight = weight
|
||||||
}
|
}
|
||||||
if first_layer then
|
|
||||||
first_layer = false
|
|
||||||
jmod.model_config = model_config
|
|
||||||
end
|
|
||||||
table.insert(jmodules, jmod)
|
table.insert(jmodules, jmod)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
local function export(model, output)
|
local function export(model, model_path, output)
|
||||||
local jmodules = {}
|
local jmodules = {}
|
||||||
local model_config = meta_data(model)
|
local model_config = meta_data(model, model_path)
|
||||||
local first_layer = true
|
local first_layer = true
|
||||||
|
|
||||||
|
print(model_config)
|
||||||
|
print(model)
|
||||||
|
|
||||||
export_weight(jmodules, model)
|
export_weight(jmodules, model)
|
||||||
|
jmodules[1]["model_config"] = model_config
|
||||||
|
|
||||||
local fp = io.open(output, "w")
|
local fp = io.open(output, "w")
|
||||||
if not fp then
|
if not fp then
|
||||||
|
@ -98,4 +108,4 @@ if not path.isfile(opt.i) then
|
||||||
os.exit(-1)
|
os.exit(-1)
|
||||||
end
|
end
|
||||||
local model = torch.load(opt.i, opt.iformat)
|
local model = torch.load(opt.i, opt.iformat)
|
||||||
export(model, opt.o)
|
export(model, opt.i, opt.o)
|
||||||
|
|
Loading…
Reference in a new issue