Add support for identity initializer for dilated convolution, and refactor
This commit is contained in:
parent
5fe6cf6ef8
commit
05bc54fa12
|
@ -4,34 +4,52 @@ require 'w2nn'
|
|||
-- ref: http://arxiv.org/abs/1501.00092
|
||||
local srcnn = {}
|
||||
|
||||
function nn.SpatialConvolutionMM:reset(stdv)
|
||||
local fin = self.kW * self.kH * self.nInputPlane
|
||||
local fout = self.kW * self.kH * self.nOutputPlane
|
||||
local function msra_filler(mod)
|
||||
local fin = mod.kW * mod.kH * mod.nInputPlane
|
||||
local fout = mod.kW * mod.kH * mod.nOutputPlane
|
||||
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
mod.weight:normal(0, stdv)
|
||||
mod.bias:zero()
|
||||
end
|
||||
local function identity_filler(mod)
|
||||
assert(mod.nInputPlane <= mod.nOutputPlane)
|
||||
mod.weight:normal(0, 0.01)
|
||||
mod.bias:zero()
|
||||
local num_groups = mod.nInputPlane -- fixed
|
||||
local filler_value = num_groups / mod.nOutputPlane
|
||||
local in_group_size = math.floor(mod.nInputPlane / num_groups)
|
||||
local out_group_size = math.floor(mod.nOutputPlane / num_groups)
|
||||
local x = math.floor(mod.kW / 2)
|
||||
local y = math.floor(mod.kH / 2)
|
||||
for i = 0, num_groups - 1 do
|
||||
for j = i * out_group_size, (i + 1) * out_group_size - 1 do
|
||||
for k = i * in_group_size, (i + 1) * in_group_size - 1 do
|
||||
mod.weight[j+1][k+1][y+1][x+1] = filler_value
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
function nn.SpatialConvolutionMM:reset(stdv)
|
||||
msra_filler(self)
|
||||
end
|
||||
function nn.SpatialFullConvolution:reset(stdv)
|
||||
local fin = self.kW * self.kH * self.nInputPlane
|
||||
local fout = self.kW * self.kH * self.nOutputPlane
|
||||
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
msra_filler(self)
|
||||
end
|
||||
function nn.SpatialDilatedConvolution:reset(stdv)
|
||||
identity_filler(self)
|
||||
end
|
||||
|
||||
if cudnn and cudnn.SpatialConvolution then
|
||||
function cudnn.SpatialConvolution:reset(stdv)
|
||||
local fin = self.kW * self.kH * self.nInputPlane
|
||||
local fout = self.kW * self.kH * self.nOutputPlane
|
||||
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
msra_filler(self)
|
||||
end
|
||||
function cudnn.SpatialFullConvolution:reset(stdv)
|
||||
local fin = self.kW * self.kH * self.nInputPlane
|
||||
local fout = self.kW * self.kH * self.nOutputPlane
|
||||
stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
|
||||
self.weight:normal(0, stdv)
|
||||
self.bias:zero()
|
||||
msra_filler(self)
|
||||
end
|
||||
if cudnn.SpatialDilatedConvolution then
|
||||
function cudnn.SpatialDilatedConvolution:reset(stdv)
|
||||
identity_filler(self)
|
||||
end
|
||||
end
|
||||
end
|
||||
function nn.SpatialConvolutionMM:clearState()
|
||||
|
@ -162,6 +180,34 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
|||
end
|
||||
srcnn.SpatialMaxPooling = SpatialMaxPooling
|
||||
|
||||
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
|
||||
if backend == "cunn" then
|
||||
return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
|
||||
elseif backend == "cudnn" then
|
||||
return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
|
||||
else
|
||||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
srcnn.SpatialAveragePooling = SpatialAveragePooling
|
||||
|
||||
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
|
||||
if backend == "cunn" then
|
||||
return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
|
||||
elseif backend == "cudnn" then
|
||||
if cudnn.SpatialDilatedConvolution then
|
||||
-- cudnn v 6
|
||||
return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
|
||||
else
|
||||
return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
|
||||
end
|
||||
else
|
||||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
|
||||
|
||||
|
||||
-- VGG style net(7 layers)
|
||||
function srcnn.vgg_7(backend, ch)
|
||||
local model = nn.Sequential()
|
||||
|
@ -555,6 +601,7 @@ function srcnn.create(model_name, backend, color)
|
|||
error("unsupported model_name: " .. model_name)
|
||||
end
|
||||
end
|
||||
|
||||
--[[
|
||||
local model = srcnn.fcn_v1("cunn", 3):cuda()
|
||||
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
|
||||
|
|
Loading…
Reference in a new issue