Add support for identity initializer for dilated convolution, and refactor
This commit is contained in:
parent
5fe6cf6ef8
commit
05bc54fa12
|
@ -4,34 +4,52 @@ require 'w2nn'
|
||||||
-- ref: http://arxiv.org/abs/1501.00092
|
-- ref: http://arxiv.org/abs/1501.00092
|
||||||
local srcnn = {}
|
local srcnn = {}
|
||||||
|
|
||||||
--- Initialize a convolution module with a fan-in/fan-out scaled Gaussian.
-- The weight is redrawn from a normal distribution with mean 0 and standard
-- deviation sqrt(4 / (1.01 * (fin + fout))); the bias is zeroed.
-- @param mod module exposing kW, kH, nInputPlane, nOutputPlane, weight, bias
local function msra_filler(mod)
   -- Fan-in / fan-out: kernel area times the input / output plane count.
   local fin = mod.kW * mod.kH * mod.nInputPlane
   local fout = mod.kW * mod.kH * mod.nOutputPlane
   -- NOTE(review): 0.1 is presumably the leaky-ReLU negative slope used by
   -- the networks in this file -- confirm.
   -- `local` is required: this helper has no `stdv` parameter, so a bare
   -- assignment would create/overwrite a global named `stdv`.
   local stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout)))
   mod.weight:normal(0, stdv)
   mod.bias:zero()
end
|
||||||
|
--- Initialize a convolution module so its kernel centers form a grouped
-- pass-through pattern.
-- The weight is first redrawn from a normal distribution (mean 0, std 0.01)
-- and the bias zeroed. Then, with one group per input plane, every output
-- plane of a group gets nInputPlane / nOutputPlane written at the kernel
-- center for that group's input plane.
-- Asserts that nInputPlane <= nOutputPlane.
-- @param mod module exposing kW, kH, nInputPlane, nOutputPlane, weight, bias;
--            weight is indexed as [out][in][row][col] (1-based)
local function identity_filler(mod)
   assert(mod.nInputPlane <= mod.nOutputPlane)
   mod.weight:normal(0, 0.01)
   mod.bias:zero()

   local groups = mod.nInputPlane -- fixed: one group per input plane
   local center_value = groups / mod.nOutputPlane
   local ins_per_group = math.floor(mod.nInputPlane / groups)
   local outs_per_group = math.floor(mod.nOutputPlane / groups)
   -- 1-based coordinates of the kernel center.
   local cx = math.floor(mod.kW / 2) + 1
   local cy = math.floor(mod.kH / 2) + 1

   for g = 1, groups do
      local out_first = (g - 1) * outs_per_group + 1
      local in_first = (g - 1) * ins_per_group + 1
      for o = out_first, out_first + outs_per_group - 1 do
         for c = in_first, in_first + ins_per_group - 1 do
            mod.weight[o][c][cy][cx] = center_value
         end
      end
   end
end
|
||||||
|
-- Replace SpatialConvolutionMM's reset so the module is initialized by
-- msra_filler. `stdv` is accepted to keep the method signature but is unused.
nn.SpatialConvolutionMM.reset = function(self, stdv)
   msra_filler(self)
end
|
||||||
-- Replace SpatialFullConvolution's reset so the module is initialized by
-- msra_filler. `stdv` is accepted to keep the method signature but is unused.
nn.SpatialFullConvolution.reset = function(self, stdv)
   msra_filler(self)
end
|
||||||
|
-- Replace SpatialDilatedConvolution's reset so the module is initialized by
-- identity_filler. `stdv` is accepted to keep the method signature but is
-- unused.
nn.SpatialDilatedConvolution.reset = function(self, stdv)
   identity_filler(self)
end
|
||||||
|
|
||||||
-- Install the same initializers on the cudnn modules, but only when cudnn is
-- loaded and provides SpatialConvolution.
if cudnn and cudnn.SpatialConvolution then
   cudnn.SpatialConvolution.reset = function(self, stdv)
      msra_filler(self)
   end

   cudnn.SpatialFullConvolution.reset = function(self, stdv)
      msra_filler(self)
   end

   -- The dilated variant is not present in every cudnn binding, so it is
   -- patched only when it exists.
   if cudnn.SpatialDilatedConvolution then
      cudnn.SpatialDilatedConvolution.reset = function(self, stdv)
         identity_filler(self)
      end
   end
end
|
||||||
function nn.SpatialConvolutionMM:clearState()
|
function nn.SpatialConvolutionMM:clearState()
|
||||||
|
@ -162,6 +180,34 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
||||||
end
|
end
|
||||||
srcnn.SpatialMaxPooling = SpatialMaxPooling
|
srcnn.SpatialMaxPooling = SpatialMaxPooling
|
||||||
|
|
||||||
|
--- Build an average-pooling layer for the requested backend.
-- @param backend "cunn" selects nn.SpatialAveragePooling,
--                "cudnn" selects cudnn.SpatialAveragePooling
-- @param kW, kH, dW, dH, padW, padH forwarded unchanged to the constructor
-- @return the module returned by the selected constructor
-- Raises an error for any other backend value.
local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH)
   if backend == "cunn" then
      return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   elseif backend == "cudnn" then
      return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH)
   else
      -- tostring() keeps the intended message when backend is nil or another
      -- value that `..` cannot concatenate; string backends are unaffected.
      error("unsupported backend:" .. tostring(backend))
   end
end
|
||||||
|
srcnn.SpatialAveragePooling = SpatialAveragePooling
|
||||||
|
|
||||||
|
--- Build a dilated-convolution layer for the requested backend.
-- @param backend "cunn" selects nn.SpatialDilatedConvolution; "cudnn" selects
--                cudnn.SpatialDilatedConvolution when the binding provides it
--                and otherwise uses nn.SpatialDilatedConvolution
-- @param nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW,
--        dilationH forwarded unchanged to the constructor
-- @return the module returned by the selected constructor
-- Raises an error for any other backend value.
local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   if backend == "cunn" then
      return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
   elseif backend == "cudnn" then
      if cudnn.SpatialDilatedConvolution then
         -- cudnn v 6
         return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
      else
         -- The cudnn binding has no dilated convolution: use the nn module.
         return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH)
      end
   else
      -- tostring() keeps the intended message when backend is nil or another
      -- value that `..` cannot concatenate; string backends are unaffected.
      error("unsupported backend:" .. tostring(backend))
   end
end
|
||||||
|
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
|
||||||
|
|
||||||
|
|
||||||
-- VGG style net(7 layers)
|
-- VGG style net(7 layers)
|
||||||
function srcnn.vgg_7(backend, ch)
|
function srcnn.vgg_7(backend, ch)
|
||||||
local model = nn.Sequential()
|
local model = nn.Sequential()
|
||||||
|
@ -555,6 +601,7 @@ function srcnn.create(model_name, backend, color)
|
||||||
error("unsupported model_name: " .. model_name)
|
error("unsupported model_name: " .. model_name)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
--[[
|
--[[
|
||||||
local model = srcnn.fcn_v1("cunn", 3):cuda()
|
local model = srcnn.fcn_v1("cunn", 3):cuda()
|
||||||
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
|
print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())
|
||||||
|
|
Loading…
Reference in a new issue