From 05bc54fa1207c8b9c65a9a6a6f2d76a6a10cde27 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Fri, 21 Apr 2017 08:54:39 +0900 Subject: [PATCH] Add support for identity initializer for dilated convolution, and refactor --- lib/srcnn.lua | 87 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 67 insertions(+), 20 deletions(-) diff --git a/lib/srcnn.lua b/lib/srcnn.lua index 82fbf51..bf6ae14 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -4,34 +4,52 @@ require 'w2nn' -- ref: http://arxiv.org/abs/1501.00092 local srcnn = {} -function nn.SpatialConvolutionMM:reset(stdv) - local fin = self.kW * self.kH * self.nInputPlane - local fout = self.kW * self.kH * self.nOutputPlane +local function msra_filler(mod) + local fin = mod.kW * mod.kH * mod.nInputPlane + local fout = mod.kW * mod.kH * mod.nOutputPlane stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout))) - self.weight:normal(0, stdv) - self.bias:zero() + mod.weight:normal(0, stdv) + mod.bias:zero() +end +local function identity_filler(mod) + assert(mod.nInputPlane <= mod.nOutputPlane) + mod.weight:normal(0, 0.01) + mod.bias:zero() + local num_groups = mod.nInputPlane -- fixed + local filler_value = num_groups / mod.nOutputPlane + local in_group_size = math.floor(mod.nInputPlane / num_groups) + local out_group_size = math.floor(mod.nOutputPlane / num_groups) + local x = math.floor(mod.kW / 2) + local y = math.floor(mod.kH / 2) + for i = 0, num_groups - 1 do + for j = i * out_group_size, (i + 1) * out_group_size - 1 do + for k = i * in_group_size, (i + 1) * in_group_size - 1 do + mod.weight[j+1][k+1][y+1][x+1] = filler_value + end + end + end +end +function nn.SpatialConvolutionMM:reset(stdv) + msra_filler(self) end function nn.SpatialFullConvolution:reset(stdv) - local fin = self.kW * self.kH * self.nInputPlane - local fout = self.kW * self.kH * self.nOutputPlane - stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout))) - self.weight:normal(0, stdv) - self.bias:zero() + msra_filler(self) end +function nn.SpatialDilatedConvolution:reset(stdv) + identity_filler(self) +end + if cudnn and cudnn.SpatialConvolution then function cudnn.SpatialConvolution:reset(stdv) - local fin = self.kW * self.kH * self.nInputPlane - local fout = self.kW * self.kH * self.nOutputPlane - stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout))) - self.weight:normal(0, stdv) - self.bias:zero() + msra_filler(self) end function cudnn.SpatialFullConvolution:reset(stdv) - local fin = self.kW * self.kH * self.nInputPlane - local fout = self.kW * self.kH * self.nOutputPlane - stdv = math.sqrt(4 / ((1.0 + 0.1 * 0.1) * (fin + fout))) - self.weight:normal(0, stdv) - self.bias:zero() + msra_filler(self) + end + if cudnn.SpatialDilatedConvolution then + function cudnn.SpatialDilatedConvolution:reset(stdv) + identity_filler(self) + end end end function nn.SpatialConvolutionMM:clearState() @@ -162,6 +180,34 @@ local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH) end srcnn.SpatialMaxPooling = SpatialMaxPooling +local function SpatialAveragePooling(backend, kW, kH, dW, dH, padW, padH) + if backend == "cunn" then + return nn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH) + elseif backend == "cudnn" then + return cudnn.SpatialAveragePooling(kW, kH, dW, dH, padW, padH) + else + error("unsupported backend:" .. backend) + end +end +srcnn.SpatialAveragePooling = SpatialAveragePooling + +local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH) + if backend == "cunn" then + return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH) + elseif backend == "cudnn" then + if cudnn.SpatialDilatedConvolution then + -- cudnn v 6 + return cudnn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH) + else + return nn.SpatialDilatedConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, dilationW, dilationH) + end + else + error("unsupported backend:" .. backend) + end +end +srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution + + -- VGG style net(7 layers) function srcnn.vgg_7(backend, ch) local model = nn.Sequential() @@ -555,6 +601,7 @@ function srcnn.create(model_name, backend, color) error("unsupported model_name: " .. model_name) end end + --[[ local model = srcnn.fcn_v1("cunn", 3):cuda() print(model:forward(torch.Tensor(1, 3, 108, 108):zero():cuda()):size())