diff --git a/tools/export_model.lua b/tools/export_model.lua index d95bd55..7e41d72 100644 --- a/tools/export_model.lua +++ b/tools/export_model.lua @@ -37,19 +37,29 @@ local function get_bias(mod) end end local function export_weight(jmodules, seq) - local targets = {"nn.SpatialConvolutionMM", - "cudnn.SpatialConvolution", - "cudnn.SpatialDilatedConvolution", - "nn.SpatialFullConvolution", - "nn.SpatialDilatedConvolution", - "cudnn.SpatialFullConvolution" + local convolutions = {"nn.SpatialConvolutionMM", + "cudnn.SpatialConvolution", + "cudnn.SpatialDilatedConvolution", + "nn.SpatialFullConvolution", + "nn.SpatialDilatedConvolution", + "cudnn.SpatialFullConvolution" } for k = 1, #seq.modules do local mod = seq.modules[k] local name = torch.typename(mod) if name == "nn.Sequential" or name == "nn.ConcatTable" then export_weight(jmodules, mod) - elseif includes(name, targets) then + elseif name == "nn.Linear" then + local weight = torch.totable(mod.weight:float()) + local jmod = { + class_name = name, + nInputPlane = mod.weight:size(2), + nOutputPlane = mod.weight:size(1), + bias = torch.totable(get_bias(mod)), + weight = weight + } + table.insert(jmodules, jmod) + elseif includes(name, convolutions) then local weight = mod.weight:float() if name:match("FullConvolution") then weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))