diff --git a/tools/export_model.lua b/tools/export_model.lua index e8d5819..d95bd55 100644 --- a/tools/export_model.lua +++ b/tools/export_model.lua @@ -5,13 +5,19 @@ package.path = path.join(path.dirname(__FILE__), "..", "lib", "?.lua;") .. packa require 'w2nn' local cjson = require "cjson" -local function meta_data(model) +local function meta_data(model, model_path) local meta = {} for k, v in pairs(model) do if k:match("w2nn_") then meta[k:gsub("w2nn_", "")] = v end end + + modtime = file.modified_time(model_path) + utc_date = Date('utc') + utc_date:set(modtime) + meta["created_at"] = tostring(utc_date) + return meta end local function includes(s, a) @@ -33,7 +39,9 @@ end local function export_weight(jmodules, seq) local targets = {"nn.SpatialConvolutionMM", "cudnn.SpatialConvolution", + "cudnn.SpatialDilatedConvolution", "nn.SpatialFullConvolution", + "nn.SpatialDilatedConvolution", "cudnn.SpatialFullConvolution" } for k = 1, #seq.modules do @@ -56,25 +64,27 @@ local function export_weight(jmodules, seq) dW = mod.dW, padW = mod.padW, padH = mod.padH, + dilationW = mod.dilationW, + dilationH = mod.dilationH, nInputPlane = mod.nInputPlane, nOutputPlane = mod.nOutputPlane, bias = torch.totable(get_bias(mod)), weight = weight } - if first_layer then - first_layer = false - jmod.model_config = model_config - end table.insert(jmodules, jmod) end end end -local function export(model, output) +local function export(model, model_path, output) local jmodules = {} - local model_config = meta_data(model) + local model_config = meta_data(model, model_path) local first_layer = true + print(model_config) + print(model) + export_weight(jmodules, model) + jmodules[1]["model_config"] = model_config local fp = io.open(output, "w") if not fp then @@ -98,4 +108,4 @@ if not path.isfile(opt.i) then os.exit(-1) end local model = torch.load(opt.i, opt.iformat) -export(model, opt.o) +export(model, opt.i, opt.o)