1
0
Fork 0
mirror of synced 2024-05-06 14:02:22 +12:00

Support Linear module

This commit is contained in:
nagadomi 2018-06-05 03:57:31 +09:00
parent 3c46906cb7
commit 929c7f85a9

View file

@ -37,19 +37,29 @@ local function get_bias(mod)
end
end
local function export_weight(jmodules, seq)
local targets = {"nn.SpatialConvolutionMM",
"cudnn.SpatialConvolution",
"cudnn.SpatialDilatedConvolution",
"nn.SpatialFullConvolution",
"nn.SpatialDilatedConvolution",
"cudnn.SpatialFullConvolution"
local convolutions = {"nn.SpatialConvolutionMM",
"cudnn.SpatialConvolution",
"cudnn.SpatialDilatedConvolution",
"nn.SpatialFullConvolution",
"nn.SpatialDilatedConvolution",
"cudnn.SpatialFullConvolution"
}
for k = 1, #seq.modules do
local mod = seq.modules[k]
local name = torch.typename(mod)
if name == "nn.Sequential" or name == "nn.ConcatTable" then
export_weight(jmodules, mod)
elseif includes(name, targets) then
elseif name == "nn.Linear" then
local weight = torch.totable(mod.weight:float())
local jmod = {
class_name = name,
nInputPlane = mod.weight:size(2),
nOutputPlane = mod.weight:size(1),
bias = torch.totable(get_bias(mod)),
weight = weight
}
table.insert(jmodules, jmod)
elseif includes(name, convolutions) then
local weight = mod.weight:float()
if name:match("FullConvolution") then
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))