1
0
Fork 0
mirror of synced 2024-05-16 10:52:20 +12:00

Add support for identity initializer for dilated convolution, and refactor

This commit is contained in:
nagadomi 2017-04-21 08:54:39 +09:00
parent 5fe6cf6ef8
commit 05bc54fa12

View file

@ -4,34 +4,52 @@ require 'w2nn'
-- ref: http://arxiv.org/abs/1501.00092
local srcnn = {}
function nn.SpatialConvolutionMM:reset(stdv)
local fin = self.kW * self.kH * self.nInputPlane
local fout = self.kW * self.kH * self.nOutputPlane
local function msra_filler(mod)
local fin = mod.kW * mod.kH * mod.nInputPlane
local fout = mod.kW * mod.kH * mod.nOutputPlane
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
self.weight:normal(0, stdv)
self.bias:zero()
mod.weight:normal(0, stdv)
mod.bias:zero()
end
local function identity_filler(mod)
assert(mod.nInputPlane <= mod.nOutputPlane)
mod.weight:normal(0, 0.01)
mod.bias:zero()
local num_groups = mod.nInputPlane -- fixed
local filler_value = num_groups / mod.nOutputPlane
local in_group_size = math.floor(mod.nInputPlane / num_groups)
local out_group_size = math.floor(mod.nOutputPlane / num_groups)
local x = math.floor(mod.kW / 2)
local y = math.floor(mod.kH / 2)
for i = 0, num_groups - 1 do
for j = i * out_group_size, (i + 1) * out_group_size - 1 do
for k = i * in_group_size, (i + 1) * in_group_size - 1 do
mod.weight[j+1][k+1][y+1][x+1] = filler_value
end
end
end
end
function nn.SpatialConvolutionMM:reset(stdv)
msra_filler(self)
end
function nn.SpatialFullConvolution:reset(stdv)
local fin = self.kW * self.kH * self.nInputPlane
local fout = self.kW * self.kH * self.nOutputPlane
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
self.weight:normal(0, stdv)
self.bias:zero()
msra_filler(self)
end
function nn.SpatialDilatedConvolution:reset(stdv)
identity_filler(self)
end
if cudnn and cudnn.SpatialConvolution then
function cudnn.SpatialConvolution:reset(stdv)
local fin = self.kW * self.kH * self.nInputPlane
local fout = self.kW * self.kH * self.nOutputPlane
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
self.weight:normal(0, stdv)
self.bias:zero()
msra_filler(self)
end
function cudnn.SpatialFullConvolution:reset(stdv)
local fin = self.kW * self.kH * self.nInputPlane
local fout = self.kW * self.kH * self.nOutputPlane
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
self.weight:normal(0, stdv)
self.bias:zero()
msra_filler(self)
end
if cudnn.SpatialDilatedConvolution then
function cudnn.SpatialDilatedConvolution:reset(stdv)
identity_filler(self)
end
end
end
function nn.SpatialConvolutionMM:clearState()
@ -162,6 +180,34 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
end
srcnn.SpatialMaxPooling = SpatialMaxPooling
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then
return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
elseif backend == "cudnn" then
return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
else
error("unsupported backend:" .. backend)
end
end
srcnn.SpatialAveragePooling = SpatialAveragePooling
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
if backend == "cunn" then
return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
elseif backend == "cudnn" then
if cudnn.SpatialDilatedConvolution then
-- cudnn v 6
return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
else
return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
end
else
error("unsupported backend:" .. backend)
end
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
-- VGG style net(7 layers)
function srcnn.vgg_7(backend, ch)
local model = nn.Sequential()
@ -555,6 +601,7 @@ function srcnn.create(model_name, backend, color)
error("unsupported model_name: " .. model_name)
end
end
--[[
local model = srcnn.fcn_v1("cunn", 3):cuda()
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())