1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00

clean; Add upcunet_v3

This commit is contained in:
nagadomi 2018-10-28 16:03:52 +09:00
parent 17b8de2d36
commit b28b6172ca
4 changed files with 249 additions and 30 deletions

View file

@ -41,6 +41,6 @@ end
-- Free cached outputs/gradients so the module serializes without transient
-- buffers (called before saving/checkpointing).
-- NOTE(review): diff view shows both the old direct :set() and the new
-- nn.utils.clear call together -- confirm only the latter remains in the file.
function AuxiliaryLossTable:clearState()
self.gradInput = {}
self.output_table = {}
self.output_tensor:set()
nn.utils.clear(self, 'output_tensor') -- release output_tensor storage
return parent:clearState()
end

View file

@ -33,7 +33,6 @@ function ScaleTable:updateGradInput(input, gradOutput)
return self.gradInput
end
-- Free temporary gradient/scale buffers before serialization.
-- NOTE(review): diff view shows both the old direct :set() calls and the new
-- nn.utils.clear call together -- confirm only the latter remains in the file.
function ScaleTable:clearState()
self.grad_tmp:set()
self.scale:set()
nn.utils.clear(self, {'grad_tmp','scale'}) -- release grad_tmp and scale storage
return parent:clearState()
end

View file

@ -218,6 +218,13 @@ local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW,
end
srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution
-- Global average pooling: average over the last two (spatial) axes, then
-- restore a 4D (-1, n_output, 1, 1) shape so 1x1 convolutions can follow.
local function GlobalAveragePooling(n_output)
   local pool = nn.Sequential()
   pool:add(nn.Mean(-1, -1))
   pool:add(nn.Mean(-1, -1))
   pool:add(nn.View(-1, n_output, 1, 1))
   return pool
end
srcnn.GlobalAveragePooling = GlobalAveragePooling
-- VGG style net(7 layers)
function srcnn.vgg_7(backend, ch)
@ -247,6 +254,7 @@ function srcnn.vgg_7(backend, ch)
return model
end
-- VGG style net(12 layers)
function srcnn.vgg_12(backend, ch)
local model = nn.Sequential()
@ -721,6 +729,38 @@ function srcnn.upconv_refine(backend, ch)
return model
end
-- I devised this architecture because of the block-size / global-average-pooling
-- problem, but SEBlock may be able to learn multi-scale input without such problems.
-- Spatial squeeze-and-excitation: local average pooling instead of a global
-- pool, with nearest-neighbor upsampling back to the input resolution.
-- Returns a ConcatTable emitting {identity, attention-map} so a following
-- CMulTable can rescale the features.
local function SpatialSEBlock(backend, ave_size, n_output, r)
   local n_squeezed = math.floor(n_output / r)
   -- excitation path: pool -> 1x1 bottleneck -> 1x1 expand -> gate -> upsample
   local gate = nn.Sequential()
   gate:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
   gate:add(SpatialConvolution(backend, n_output, n_squeezed, 1, 1, 1, 1, 0, 0))
   gate:add(nn.ReLU(true))
   gate:add(SpatialConvolution(backend, n_squeezed, n_output, 1, 1, 1, 1, 0, 0))
   gate:add(nn.Sigmoid(true))
   gate:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
   local branches = nn.ConcatTable(2)
   branches:add(nn.Identity())
   branches:add(gate)
   return branches
end
-- Squeeze and Excitation Block: global pooling followed by a bottlenecked
-- pair of 1x1 convolutions producing per-channel sigmoid weights.
-- Returns a ConcatTable emitting {identity, channel-weights} for a
-- downstream ScaleTable.
local function SEBlock(backend, n_output, r)
   local n_bottleneck = math.floor(n_output / r)
   -- excitation path: global pool -> squeeze -> excite -> sigmoid gate
   local excite = nn.Sequential()
   excite:add(GlobalAveragePooling(n_output))
   excite:add(SpatialConvolution(backend, n_output, n_bottleneck, 1, 1, 1, 1, 0, 0))
   excite:add(nn.ReLU(true))
   excite:add(SpatialConvolution(backend, n_bottleneck, n_output, 1, 1, 1, 1, 0, 0))
   excite:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
   local pair = nn.ConcatTable(2)
   pair:add(nn.Identity())
   pair:add(excite)
   return pair
end
-- cascaded residual channel attention unet
function srcnn.upcunet(backend, ch)
function unet_branch(insert, backend, n_input, n_output, depad)
@ -744,17 +784,7 @@ function srcnn.upcunet(backend, ch)
model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
if se then
-- Squeeze and Excitation Networks
local con = nn.ConcatTable(2)
local attention = nn.Sequential()
attention:add(nn.SpatialAdaptiveAveragePooling(1, 1)) -- global average pooling
attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / 4), 1, 1, 1, 1, 0, 0))
attention:add(nn.ReLU(true))
attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0))
attention:add(nn.Sigmoid(true))
con:add(nn.Identity())
con:add(attention)
model:add(con)
model:add(SEBlock(backend, n_output, 4))
model:add(w2nn.ScaleTable())
end
return model
@ -799,8 +829,10 @@ function srcnn.upcunet(backend, ch)
model.w2nn_scale_factor = 2
model.w2nn_channels = ch
model.w2nn_resize = true
-- 72, 128, 256 are valid
--model.w2nn_input_size = 128
model.w2nn_valid_input_size = {}
for i = 76, 512, 4 do
table.insert(model.w2nn_valid_input_size, i)
end
return model
end
@ -828,19 +860,7 @@ function srcnn.upcunet_v2(backend, ch)
model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
if se then
-- Spatial Squeeze and Excitation Networks
local se_fac = 4
local con = nn.ConcatTable(2)
local attention = nn.Sequential()
attention:add(SpatialAveragePooling(backend, 4, 4, 4, 4))
attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0))
attention:add(nn.ReLU(true))
attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0))
attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
attention:add(nn.SpatialUpSamplingNearest(4, 4))
con:add(nn.Identity())
con:add(attention)
model:add(con)
model:add(SpatialSEBlock(backend, 4, n_output, 4))
model:add(nn.CMulTable())
end
return model
@ -888,11 +908,89 @@ function srcnn.upcunet_v2(backend, ch)
return model
end
-- cascaded residual channel attention unet
function srcnn.upcunet_v3(backend, ch)
-- Residual skip block: 2x downsample -> `insert` -> 2x upsample, summed
-- with a center-cropped (negative padding by `depad`) identity branch.
local function unet_branch(insert, backend, n_input, n_output, depad)
local block = nn.Sequential()
local con = nn.ConcatTable(2)
local model = nn.Sequential()
block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling
block:add(nn.LeakyReLU(0.1, true))
block:add(insert)
block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
block:add(nn.LeakyReLU(0.1, true))
-- negative padding crops the identity branch to match the block output size
con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
con:add(block)
model:add(con)
model:add(nn.CAddTable())
return model
end
-- Two unpadded 3x3 convolutions, optionally followed by channel attention
-- (SEBlock producing {features, weights}, applied via ScaleTable).
local function unet_conv(n_input, n_middle, n_output, se)
local model = nn.Sequential()
model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
if se then
model:add(SEBlock(backend, n_output, 4))
model:add(w2nn.ScaleTable())
end
return model
end
-- Residual U-Net
-- Two nested unet_branch levels; when `deconv` is true the final layer is a
-- stride-2 deconvolution (the first unet of the cascade also upscales 2x).
local function unet(backend, ch, deconv)
local block1 = unet_conv(128, 256, 128, true)
local block2 = nn.Sequential()
block2:add(unet_conv(64, 64, 128, true))
block2:add(unet_branch(block1, backend, 128, 128, 4))
block2:add(unet_conv(128, 64, 64, true))
local model = nn.Sequential()
model:add(unet_conv(ch, 32, 64, false))
model:add(unet_branch(block2, backend, 64, 64, 16))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
if deconv then
model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
else
model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
end
return model
end
local model = nn.Sequential()
local con = nn.ConcatTable()
local aux_con = nn.ConcatTable()
-- 2 cascade
-- con output: {unet2(unet1_out), cropped unet1_out}
model:add(unet(backend, ch, true))
con:add(unet(backend, ch, false))
con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
model:add(con)
model:add(aux_con)
model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
model.w2nn_arch_name = "upcunet_v3"
model.w2nn_offset = 60
model.w2nn_scale_factor = 2
model.w2nn_channels = ch
model.w2nn_resize = true
-- Valid input sizes are constrained by the unpadded convs and 2x up/down
-- stages (presumably derived with tools/find_unet.py -- confirm).
model.w2nn_valid_input_size = {}
for i = 76, 512, 4 do
table.insert(model.w2nn_valid_input_size, i)
end
return model
end
local function bench()
local sys = require 'sys'
cudnn.benchmark = true
local model = nil
local arch = {"upconv_7", "upcunet", "upcunet_v2"}
local arch = {"upconv_7", "upcunet", "upcunet_v3"}
local backend = "cudnn"
for k = 1, #arch do
model = srcnn[arch[k]](backend, 3):cuda()
@ -947,6 +1045,12 @@ print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
os.exit()
local model = srcnn.upcunet_v3("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
os.exit()
bench()
--]]
return srcnn

116
tools/find_unet.py Normal file
View file

@ -0,0 +1,116 @@
def find_unet_v2():
    """Scan candidate input sizes (76..511) for the cascaded u-net with
    spatial average pooling (upcunet_v2 style).

    Simulates the spatial size through both cascaded u-nets and keeps only
    sizes whose feature maps stay divisible by the pooling factor at every
    SE checkpoint. Prints "ok <input_size> <output_size>" for each valid
    size. Returns None.
    """
    avg_pool = 4       # SpatialAveragePooling kernel/stride of the SE block
    print_mod = False  # debug: dump modulus info at each checkpoint
    check_mod = True   # enforce the divisibility constraint
    print("cascade")
    for i in range(76, 512):
        print("-- {}".format(i))
        s = i
        # unet 1
        s = s - 4   # conv3x3x2
        # Floor division: the original Python 2 "/" was integer division;
        # "//" keeps sizes integral under Python 3 as well.
        s = s // 2  # down2x2
        s = s - 4   # conv3x3x2
        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
        if check_mod and s % avg_pool != 0:
            continue
        s = s // 2  # down2x2
        s = s - 4   # conv3x3x2
        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
        if check_mod and s % avg_pool != 0:
            continue
        s = s * 2   # up2x2
        s = s - 4   # conv3x3x2
        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
        if check_mod and s % avg_pool != 0:
            continue
        s = s * 2   # up2x2
        # deconv: doubles the size and trims 4
        s = s * 2 - 4
        # unet 2
        s = s - 4   # conv3x3x2
        s = s // 2  # down2x2
        s = s - 4   # conv3x3x2
        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
        if check_mod and s % avg_pool != 0:
            continue
        s = s // 2  # down2x2
        s = s - 4   # conv3x3x2
        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
        if check_mod and s % avg_pool != 0:
            continue
        s = s * 2   # up2x2
        s = s - 4   # conv3x3x2
        if print_mod: print(s, s % 2, s % 4, s % 6, s % 8)
        if check_mod and s % avg_pool != 0:
            continue
        s = s * 2   # up2x2
        s = s - 2   # conv3x3 last
        print("ok", i, s)
def find_unet():
    """Scan candidate input sizes (76..511) for the cascaded residual u-net
    (upcunet_v3 style).

    Simulates the spatial size through both cascaded u-nets and keeps only
    sizes where every 2x downsampling sees an even feature map. Prints
    "ok <input_size> <output_size>" for each valid size. Returns None.
    """
    check_mod = True    # enforce the even-size constraint before each down2x2
    print_size = False  # debug: trace the size at each stage
    print("cascade")
    for i in range(76, 512):
        s = i
        # unet 1
        s = s - 4   # conv3x3x2
        if print_size: print("1/2", s)
        if check_mod and s % 2 != 0:
            continue
        # Floor division: the original Python 2 "/" was integer division;
        # "//" keeps sizes integral under Python 3 as well.
        s = s // 2  # down2x2
        s = s - 4   # conv3x3x2
        if print_size: print("1/2", s)
        if check_mod and s % 2 != 0:
            continue
        s = s // 2  # down2x2
        s = s - 4   # conv3x3x2
        s = s * 2   # up2x2
        if print_size: print("2x", s)
        s = s - 4   # conv3x3x2
        s = s * 2   # up2x2
        if print_size: print("2x", s)
        # deconv: conv3x3, then double the size and trim 4
        s = s - 2
        s = s * 2 - 4
        # unet 2
        s = s - 4   # conv3x3x2
        if print_size: print("1/2", s)
        if check_mod and s % 2 != 0:
            continue
        s = s // 2  # down2x2
        s = s - 4   # conv3x3x2
        if print_size: print("1/2", s)
        if check_mod and s % 2 != 0:
            continue
        s = s // 2  # down2x2
        s = s - 4   # conv3x3x2
        s = s * 2   # up2x2
        if print_size: print("2x", s)
        s = s - 4   # conv3x3x2
        s = s * 2   # up2x2
        if print_size: print("2x", s)
        s = s - 2   # conv3x3
        s = s - 2   # conv3x3 last
        print("ok", i, s)
find_unet()