Support Linear module
This commit is contained in:
parent
3c46906cb7
commit
929c7f85a9
|
@ -37,7 +37,7 @@ local function get_bias(mod)
|
|||
end
|
||||
end
|
||||
local function export_weight(jmodules, seq)
|
||||
local targets = {"nn.SpatialConvolutionMM",
|
||||
local convolutions = {"nn.SpatialConvolutionMM",
|
||||
"cudnn.SpatialConvolution",
|
||||
"cudnn.SpatialDilatedConvolution",
|
||||
"nn.SpatialFullConvolution",
|
||||
|
@ -49,7 +49,17 @@ local function export_weight(jmodules, seq)
|
|||
local name = torch.typename(mod)
|
||||
if name == "nn.Sequential" or name == "nn.ConcatTable" then
|
||||
export_weight(jmodules, mod)
|
||||
elseif includes(name, targets) then
|
||||
elseif name == "nn.Linear" then
|
||||
local weight = torch.totable(mod.weight:float())
|
||||
local jmod = {
|
||||
class_name = name,
|
||||
nInputPlane = mod.weight:size(2),
|
||||
nOutputPlane = mod.weight:size(1),
|
||||
bias = torch.totable(get_bias(mod)),
|
||||
weight = weight
|
||||
}
|
||||
table.insert(jmodules, jmod)
|
||||
elseif includes(name, convolutions) then
|
||||
local weight = mod.weight:float()
|
||||
if name:match("FullConvolution") then
|
||||
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
|
||||
|
|
Loading…
Reference in a new issue