From b28b6172ca471f9a76786aeefb2e882c812f900d Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 28 Oct 2018 16:03:52 +0900 Subject: [PATCH] clean; Add upcunet_v3 --- lib/AuxiliaryLossTable.lua | 2 +- lib/ScaleTable.lua | 3 +- lib/srcnn.lua | 158 ++++++++++++++++++++++++++++++------- tools/find_unet.py | 116 +++++++++++++++++++++++++++ 4 files changed, 249 insertions(+), 30 deletions(-) create mode 100644 tools/find_unet.py diff --git a/lib/AuxiliaryLossTable.lua b/lib/AuxiliaryLossTable.lua index cc1bffc..aa44d0b 100644 --- a/lib/AuxiliaryLossTable.lua +++ b/lib/AuxiliaryLossTable.lua @@ -41,6 +41,6 @@ end function AuxiliaryLossTable:clearState() self.gradInput = {} self.output_table = {} - self.output_tensor:set() + nn.utils.clear(self, 'output_tensor') return parent:clearState() end diff --git a/lib/ScaleTable.lua b/lib/ScaleTable.lua index f2b0ce9..2363162 100644 --- a/lib/ScaleTable.lua +++ b/lib/ScaleTable.lua @@ -33,7 +33,6 @@ function ScaleTable:updateGradInput(input, gradOutput) return self.gradInput end function ScaleTable:clearState() - self.grad_tmp:set() - self.scale:set() + nn.utils.clear(self, {'grad_tmp','scale'}) return parent:clearState() end diff --git a/lib/srcnn.lua b/lib/srcnn.lua index c064256..9c2fe19 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -218,6 +218,13 @@ local function SpatialDilatedConvolution(backend, nInputPlane, nOutputPlane, kW, end srcnn.SpatialDilatedConvolution = SpatialDilatedConvolution +local function GlobalAveragePooling(n_output) + local gap = nn.Sequential() + gap:add(nn.Mean(-1, -1)):add(nn.Mean(-1, -1)) + gap:add(nn.View(-1, n_output, 1, 1)) + return gap +end +srcnn.GlobalAveragePooling = GlobalAveragePooling -- VGG style net(7 layers) function srcnn.vgg_7(backend, ch) @@ -247,6 +254,7 @@ function srcnn.vgg_7(backend, ch) return model end + -- VGG style net(12 layers) function srcnn.vgg_12(backend, ch) local model = nn.Sequential() @@ -721,6 +729,38 @@ function srcnn.upconv_refine(backend, ch) return model end +-- I devised this arch because of the block size and global average pooling problem, +-- but SEBlock may possibly learn multi-scale input and no problems occur. +local function SpatialSEBlock(backend, ave_size, n_output, r) + local con = nn.ConcatTable(2) + local attention = nn.Sequential() + local n_mid = math.floor(n_output / r) + attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size)) + attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0)) + attention:add(nn.ReLU(true)) + attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0)) + attention:add(nn.Sigmoid(true)) + attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size)) + con:add(nn.Identity()) + con:add(attention) + return con +end + +-- Squeeze and Excitation Block +local function SEBlock(backend, n_output, r) + local con = nn.ConcatTable(2) + local attention = nn.Sequential() + local n_mid = math.floor(n_output / r) + attention:add(GlobalAveragePooling(n_output)) + attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0)) + attention:add(nn.ReLU(true)) + attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0)) + attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid + con:add(nn.Identity()) + con:add(attention) + return con +end + -- cascaded residual channel attention unet function srcnn.upcunet(backend, ch) function unet_branch(insert, backend, n_input, n_output, depad) @@ -744,17 +784,7 @@ function srcnn.upcunet(backend, ch) model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1, true)) if se then - -- Squeeze and Excitation Networks - local con = nn.ConcatTable(2) - local attention = nn.Sequential() - attention:add(nn.SpatialAdaptiveAveragePooling(1, 1)) -- global average pooling - attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / 4), 1, 1, 1, 1, 0, 0)) - attention:add(nn.ReLU(true)) - attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0)) - attention:add(nn.Sigmoid(true)) - con:add(nn.Identity()) - con:add(attention) - model:add(con) + model:add(SEBlock(backend, n_output, 4)) model:add(w2nn.ScaleTable()) end return model @@ -799,8 +829,10 @@ function srcnn.upcunet(backend, ch) model.w2nn_scale_factor = 2 model.w2nn_channels = ch model.w2nn_resize = true - -- 72, 128, 256 are valid - --model.w2nn_input_size = 128 + model.w2nn_valid_input_size = {} + for i = 76, 512, 4 do + table.insert(model.w2nn_valid_input_size, i) + end return model end @@ -828,19 +860,7 @@ function srcnn.upcunet_v2(backend, ch) model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0)) model:add(nn.LeakyReLU(0.1, true)) if se then - -- Spatial Squeeze and Excitation Networks - local se_fac = 4 - local con = nn.ConcatTable(2) - local attention = nn.Sequential() - attention:add(SpatialAveragePooling(backend, 4, 4, 4, 4)) - attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0)) - attention:add(nn.ReLU(true)) - attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0)) - attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid - attention:add(nn.SpatialUpSamplingNearest(4, 4)) - con:add(nn.Identity()) - con:add(attention) - model:add(con) + model:add(SpatialSEBlock(backend, 4, n_output, 4)) model:add(nn.CMulTable()) end return model @@ -888,11 +908,89 @@ function srcnn.upcunet_v2(backend, ch) return model end +-- cascaded residual channel attention unet +function srcnn.upcunet_v3(backend, ch) + local function unet_branch(insert, backend, n_input, n_output, depad) + local block = nn.Sequential() + local con = nn.ConcatTable(2) + local model = nn.Sequential() + + block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling + block:add(nn.LeakyReLU(0.1, true)) + block:add(insert) + block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling + block:add(nn.LeakyReLU(0.1, true)) + con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad)) + con:add(block) + model:add(con) + model:add(nn.CAddTable()) + return model + end + local function unet_conv(n_input, n_middle, n_output, se) + local model = nn.Sequential() + model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1, true)) + if se then + model:add(SEBlock(backend, n_output, 4)) + model:add(w2nn.ScaleTable()) + end + return model + end + -- Residual U-Net + local function unet(backend, ch, deconv) + local block1 = unet_conv(128, 256, 128, true) + local block2 = nn.Sequential() + block2:add(unet_conv(64, 64, 128, true)) + block2:add(unet_branch(block1, backend, 128, 128, 4)) + block2:add(unet_conv(128, 64, 64, true)) + local model = nn.Sequential() + model:add(unet_conv(ch, 32, 64, false)) + model:add(unet_branch(block2, backend, 64, 64, 16)) + model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0)) + model:add(nn.LeakyReLU(0.1)) + if deconv then + model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3)) + else + model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0)) + end + return model + end + local model = nn.Sequential() + local con = nn.ConcatTable() + local aux_con = nn.ConcatTable() + + -- 2 cascade + model:add(unet(backend, ch, true)) + con:add(unet(backend, ch, false)) + con:add(nn.SpatialZeroPadding(-20, -20, -20, -20)) + + aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output + aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output + + model:add(con) + model:add(aux_con) + model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output + + model.w2nn_arch_name = "upcunet_v3" + model.w2nn_offset = 60 + model.w2nn_scale_factor = 2 + model.w2nn_channels = ch + model.w2nn_resize = true + model.w2nn_valid_input_size = {} + for i = 76, 512, 4 do + table.insert(model.w2nn_valid_input_size, i) + end + + return model +end + local function bench() local sys = require 'sys' cudnn.benchmark = true local model = nil - local arch = {"upconv_7", "upcunet", "upcunet_v2"} + local arch = {"upconv_7", "upcunet", "upcunet_v3"} local backend = "cudnn" for k = 1, #arch do model = srcnn[arch[k]](backend, 3):cuda() @@ -947,6 +1045,12 @@ print(model) model:training() print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda())) os.exit() +local model = srcnn.upcunet_v3("cunn", 3):cuda() +print(model) +model:training() +print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda())) +os.exit() +bench() --]] return srcnn diff --git a/tools/find_unet.py b/tools/find_unet.py new file mode 100644 index 0000000..515c1bc --- /dev/null +++ b/tools/find_unet.py @@ -0,0 +1,116 @@ +def find_unet_v2(): + avg_pool=4 + print_mod = False + check_mod = True + print("cascade") + + for i in range(76, 512): + print("-- {}".format(i)) + print_buf = [] + s = i + # unet 1 + + s = s - 4 # conv3x3x2 + s = s / 2 # down2x2 + s = s - 4 # conv3x3x2 + if print_mod: print(s, s % 2, s % 4, s % 6, s % 8) + if check_mod and s % avg_pool != 0: + continue + + s = s / 2 # down2x2 + s = s - 4 # conv3x3x2 + + if print_mod: print(s, s % 2, s % 4, s % 6, s % 8) + if check_mod and s % avg_pool != 0: + continue + s = s * 2 # up2x2 + s = s - 4 # conv3x3x2 + if print_mod: print(s, s % 2, s % 4, s % 6, s % 8) + if check_mod and s % avg_pool != 0: + continue + s = s * 2 # up2x2 + + # deconv + s = s + s = s * 2 - 4 + + # unet 2 + s = s - 4 # conv3x3x2 + s = s / 2 # down2x2 + s = s - 4 # conv3x3x2 + if print_mod: print(s, s % 2, s % 4, s % 6, s % 8) + if check_mod and s % avg_pool != 0: + continue + s = s / 2 # down2x2 + s = s - 4 # conv3x3x2 + if print_mod: print(s, s % 2, s % 4, s % 6, s % 8) + if check_mod and s % avg_pool != 0: + continue + s = s * 2 # up2x2 + s = s - 4 # conv3x3x2 + if print_mod: print(s, s % 2, s % 4, s % 6, s % 8) + if check_mod and s % avg_pool != 0: + continue + s = s * 2 # up2x2 + s = s - 2 # conv3x3 last + #if s % avg_pool != 0: + # continue + print("ok", i, s) + +def find_unet(): + check_mod = True + print_size = False + print("cascade") + + for i in range(76, 512): + print_buf = [] + s = i + # unet 1 + + s = s - 4 # conv3x3x2 + if print_size: print("1/2", s) + if check_mod and s % 2 != 0: + continue + s = s / 2 # down2x2 + s = s - 4 # conv3x3x2 + if print_size: print("1/2",s) + if check_mod and s % 2 != 0: + continue + s = s / 2 # down2x2 + s = s - 4 # conv3x3x2 + + s = s * 2 # up2x2 + if print_size: print("2x",s) + s = s - 4 # conv3x3x2 + s = s * 2 # up2x2 + if print_size: print("2x",s) + + # deconv + s = s - 2 + s = s * 2 - 4 + + # unet 2 + s = s - 4 # conv3x3x2 + if print_size: print("1/2",s) + if check_mod and s % 2 != 0: + continue + s = s / 2 # down2x2 + s = s - 4 # conv3x3x2 + if print_size: print("1/2",s) + if check_mod and s % 2 != 0: + continue + s = s / 2 # down2x2 + s = s - 4 # conv3x3x2 + s = s * 2 # up2x2 + if print_size: print("2x",s) + s = s - 4 # conv3x3x2 + s = s * 2 # up2x2 + if print_size: print("2x",s) + s = s - 2 # conv3x3 + s = s - 2 # conv3x3 last + #if s % avg_pool != 0: + # continue + print("ok", i, s) + +find_unet() +