Support Linear module
This commit is contained in:
parent
3c46906cb7
commit
929c7f85a9
|
@ -37,19 +37,29 @@ local function get_bias(mod)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
local function export_weight(jmodules, seq)
|
local function export_weight(jmodules, seq)
|
||||||
local targets = {"nn.SpatialConvolutionMM",
|
local convolutions = {"nn.SpatialConvolutionMM",
|
||||||
"cudnn.SpatialConvolution",
|
"cudnn.SpatialConvolution",
|
||||||
"cudnn.SpatialDilatedConvolution",
|
"cudnn.SpatialDilatedConvolution",
|
||||||
"nn.SpatialFullConvolution",
|
"nn.SpatialFullConvolution",
|
||||||
"nn.SpatialDilatedConvolution",
|
"nn.SpatialDilatedConvolution",
|
||||||
"cudnn.SpatialFullConvolution"
|
"cudnn.SpatialFullConvolution"
|
||||||
}
|
}
|
||||||
for k = 1, #seq.modules do
|
for k = 1, #seq.modules do
|
||||||
local mod = seq.modules[k]
|
local mod = seq.modules[k]
|
||||||
local name = torch.typename(mod)
|
local name = torch.typename(mod)
|
||||||
if name == "nn.Sequential" or name == "nn.ConcatTable" then
|
if name == "nn.Sequential" or name == "nn.ConcatTable" then
|
||||||
export_weight(jmodules, mod)
|
export_weight(jmodules, mod)
|
||||||
elseif includes(name, targets) then
|
elseif name == "nn.Linear" then
|
||||||
|
local weight = torch.totable(mod.weight:float())
|
||||||
|
local jmod = {
|
||||||
|
class_name = name,
|
||||||
|
nInputPlane = mod.weight:size(2),
|
||||||
|
nOutputPlane = mod.weight:size(1),
|
||||||
|
bias = torch.totable(get_bias(mod)),
|
||||||
|
weight = weight
|
||||||
|
}
|
||||||
|
table.insert(jmodules, jmod)
|
||||||
|
elseif includes(name, convolutions) then
|
||||||
local weight = mod.weight:float()
|
local weight = mod.weight:float()
|
||||||
if name:match("FullConvolution") then
|
if name:match("FullConvolution") then
|
||||||
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
|
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))
|
||||||
|
|
Loading…
Reference in a new issue