1
0
Fork 0
mirror of synced 2024-05-19 12:22:20 +12:00

Support Linear module

This commit is contained in:
nagadomi 2018-06-05 03:57:31 +09:00
parent 3c46906cb7
commit 929c7f85a9

View file

@ -37,19 +37,29 @@ local function get_bias(mod)
end end
end end
local function export_weight(jmodules, seq) local function export_weight(jmodules, seq)
local targets = {"nn.SpatialConvolutionMM", local convolutions = {"nn.SpatialConvolutionMM",
"cudnn.SpatialConvolution", "cudnn.SpatialConvolution",
"cudnn.SpatialDilatedConvolution", "cudnn.SpatialDilatedConvolution",
"nn.SpatialFullConvolution", "nn.SpatialFullConvolution",
"nn.SpatialDilatedConvolution", "nn.SpatialDilatedConvolution",
"cudnn.SpatialFullConvolution" "cudnn.SpatialFullConvolution"
} }
for k = 1, #seq.modules do for k = 1, #seq.modules do
local mod = seq.modules[k] local mod = seq.modules[k]
local name = torch.typename(mod) local name = torch.typename(mod)
if name == "nn.Sequential" or name == "nn.ConcatTable" then if name == "nn.Sequential" or name == "nn.ConcatTable" then
export_weight(jmodules, mod) export_weight(jmodules, mod)
elseif includes(name, targets) then elseif name == "nn.Linear" then
local weight = torch.totable(mod.weight:float())
local jmod = {
class_name = name,
nInputPlane = mod.weight:size(2),
nOutputPlane = mod.weight:size(1),
bias = torch.totable(get_bias(mod)),
weight = weight
}
table.insert(jmodules, jmod)
elseif includes(name, convolutions) then
local weight = mod.weight:float() local weight = mod.weight:float()
if name:match("FullConvolution") then if name:match("FullConvolution") then
weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW)) weight = torch.totable(weight:reshape(mod.nInputPlane, mod.nOutputPlane, mod.kH, mod.kW))