Support Linear module
This commit is contained in:
parent
3c46906cb7
commit
929c7f85a9
|
@ -37,19 +37,29 @@ local function get_bias(mod)
|
|||
end
|
||||
end
|
||||
local function export_weight(jmodules, seq)
|
||||
local targets = {"nn.SpatialConvolutionMM",
|
||||
"cudnn.SpatialConvolution",
|
||||
"cudnn.SpatialDilatedConvolution",
|
||||
"nn.SpatialFullConvolution",
|
||||
"nn.SpatialDilatedConvolution",
|
||||
"cudnn.SpatialFullConvolution"
|
||||
local convolutions = {"nn.SpatialConvolutionMM",
|
||||
"cudnn.SpatialConvolution",
|
||||
"cudnn.SpatialDilatedConvolution",
|
||||
"nn.SpatialFullConvolution",
|
||||
"nn.SpatialDilatedConvolution",
|
||||
"cudnn.SpatialFullConvolution"
|
||||
}
|
||||
for k = 1, #seq.modules do
|
||||
local mod = seq.modules[k]
|
||||
local name = torch.typename(mod)
|
||||
if name == "nn.Sequential" or name == "nn.ConcatTable" then
|
||||
export_weight(jmodules, mod)
|
||||
elseif includes(name, targets) then
|
||||
elseif name == "nn.Linear" then
|
||||
local weight = torch.totable(mod.weight:float())
|
||||
local jmod = {
|
||||
class_name = name,
|
||||
nInputPlane = mod.weight:size(2),
|
||||
nOutputPlane = mod.weight:size(1),
|
||||
bias = torch.totable(get_bias(mod)),
|
||||
weight = weight
|
||||
}
|
||||
table.insert(jmodules, jmod)
|
||||
elseif includes(name, convolutions) then
|
||||
local weight = mod.weight:float()
|
||||
if name:match("FullConvolution") then
|
||||
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
|
||||
|
|
Loading…
Reference in a new issue