diff --git a/lib/srcnn.lua b/lib/srcnn.lua
index 9c2fe19..24a9407 100644
--- a/lib/srcnn.lua
+++ b/lib/srcnn.lua
@@ -255,77 +255,6 @@ function srcnn.vgg_7(backend, ch)
    return model
 end
 
--- VGG style net(12 layers)
-function srcnn.vgg_12(backend, ch)
-   local model = nn.Sequential()
-   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.InplaceClip01())
-   model:add(nn.View(-1):setNumInputDims(3))
-
-   model.w2nn_arch_name = "vgg_12"
-   model.w2nn_offset = 12
-   model.w2nn_scale_factor = 1
-   model.w2nn_resize = false
-   model.w2nn_channels = ch
-   --model:cuda()
-   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
-
-   return model
-end
-
--- Dilated Convolution (7 layers)
-function srcnn.dilated_7(backend, ch)
-   local model = nn.Sequential()
-   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(nn.SpatialDilatedConvolution(64, 64, 3, 3, 1, 1, 0, 0, 2, 2))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 4, 4))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0))
-   model:add(w2nn.InplaceClip01())
-   model:add(nn.View(-1):setNumInputDims(3))
-
-   model.w2nn_arch_name = "dilated_7"
-   model.w2nn_offset = 12
-   model.w2nn_scale_factor = 1
-   model.w2nn_resize = false
-   model.w2nn_channels = ch
-
-   --model:cuda()
-   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
-
-   return model
-end
-
 -- Upconvolution
 function srcnn.upconv_7(backend, ch)
    local model = nn.Sequential()
@@ -387,121 +316,6 @@ function srcnn.upconv_7l(backend, ch)
    return model
 end
 
--- layerwise linear blending with skip connections
--- Note: PSNR: upconv_7 < skiplb_7 < upconv_7l
-function srcnn.skiplb_7(backend, ch)
-   local function skip(backend, i, o)
-      local con = nn.Concat(2)
-      local conv = nn.Sequential()
-      conv:add(SpatialConvolution(backend, i, o, 3, 3, 1, 1, 1, 1))
-      conv:add(nn.LeakyReLU(0.1, true))
-
-      -- depth concat
-      con:add(conv)
-      con:add(nn.Identity()) -- skip
-      return con
-   end
-   local model = nn.Sequential()
-   model:add(skip(backend, ch, 16))
-   model:add(skip(backend, 16+ch, 32))
-   model:add(skip(backend, 32+16+ch, 64))
-   model:add(skip(backend, 64+32+16+ch, 128))
-   model:add(skip(backend, 128+64+32+16+ch, 128))
-   model:add(skip(backend, 128+128+64+32+16+ch, 256))
-   -- input of last layer = [all layerwise output(contains input layer)].flatten
-   model:add(SpatialFullConvolution(backend, 256+128+128+64+32+16+ch, ch, 4, 4, 2, 2, 3, 3):noBias()) -- linear blend
-   model:add(w2nn.InplaceClip01())
-   model:add(nn.View(-1):setNumInputDims(3))
-   model.w2nn_arch_name = "skiplb_7"
-   model.w2nn_offset = 14
-   model.w2nn_scale_factor = 2
-   model.w2nn_resize = true
-   model.w2nn_channels = ch
-
-   --model:cuda()
-   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
-
-   return model
-end
-
--- dilated convolution + deconvolution
--- Note: This model is not better than upconv_7. Maybe becuase of under-fitting.
-function srcnn.dilated_upconv_7(backend, ch)
-   local model = nn.Sequential()
-   model:add(SpatialConvolution(backend, ch, 16, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 16, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(nn.SpatialDilatedConvolution(32, 64, 3, 3, 1, 1, 0, 0, 2, 2))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(nn.SpatialDilatedConvolution(64, 128, 3, 3, 1, 1, 0, 0, 2, 2))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(nn.SpatialDilatedConvolution(128, 128, 3, 3, 1, 1, 0, 0, 2, 2))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
-   model:add(w2nn.InplaceClip01())
-   model:add(nn.View(-1):setNumInputDims(3))
-
-   model.w2nn_arch_name = "dilated_upconv_7"
-   model.w2nn_offset = 20
-   model.w2nn_scale_factor = 2
-   model.w2nn_resize = true
-   model.w2nn_channels = ch
-
-   --model:cuda()
-   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
-
-   return model
-end
-
--- ref: https://arxiv.org/abs/1609.04802
--- note: no batch-norm, no zero-paading
-function srcnn.srresnet_2x(backend, ch)
-   local function resblock(backend)
-      local seq = nn.Sequential()
-      local con = nn.ConcatTable()
-      local conv = nn.Sequential()
-      conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      conv:add(ReLU(backend))
-      conv:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
-      conv:add(ReLU(backend))
-      con:add(conv)
-      con:add(nn.SpatialZeroPadding(-2, -2, -2, -2)) -- identity + de-padding
-      seq:add(con)
-      seq:add(nn.CAddTable())
-      return seq
-   end
-   local model = nn.Sequential()
-   --model:add(skip(backend, ch, 64 - ch))
-   model:add(SpatialConvolution(backend, ch, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(resblock(backend))
-   model:add(resblock(backend))
-   model:add(resblock(backend))
-   model:add(resblock(backend))
-   model:add(resblock(backend))
-   model:add(resblock(backend))
-   model:add(SpatialFullConvolution(backend, 64, 64, 4, 4, 2, 2, 2, 2))
-   model:add(ReLU(backend))
-   model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-
-   model:add(w2nn.InplaceClip01())
-   --model:add(nn.View(-1):setNumInputDims(3))
-   model.w2nn_arch_name = "srresnet_2x"
-   model.w2nn_offset = 28
-   model.w2nn_scale_factor = 2
-   model.w2nn_resize = true
-   model.w2nn_channels = ch
-
-   --model:cuda()
-   --print(model:forward(torch.Tensor(32, ch, 92, 92):uniform():cuda()):size())
-
-   return model
-end
-
--- large version of srresnet_2x. It's current best model but slow.
 function srcnn.resnet_14l(backend, ch)
    local function resblock(backend, i, o)
       local seq = nn.Sequential()
@@ -601,150 +415,6 @@ function srcnn.fcn_v1(backend, ch)
    return model
 end
 
-function srcnn.cupconv_14(backend, ch)
-   local function skip(backend, n_input, n_output, pad)
-      local con = nn.ConcatTable()
-      local conv = nn.Sequential()
-      local depad = nn.Sequential()
-      conv:add(nn.SelectTable(1))
-      conv:add(SpatialConvolution(backend, n_input, n_output, 3, 3, 1, 1, 0, 0))
-      conv:add(nn.LeakyReLU(0.1, true))
-      con:add(conv)
-      con:add(nn.Identity())
-      return con
-   end
-   local function concat(backend, n, ch, n_middle)
-      local con = nn.ConcatTable()
-      for i = 1, n do
-         local pad = i - 1
-         if i == 1 then
-            con:add(nn.Sequential():add(nn.SelectTable(i)))
-         else
-            local seq = nn.Sequential()
-            seq:add(nn.SelectTable(i))
-            if pad > 0 then
-               seq:add(nn.SpatialZeroPadding(-pad, -pad, -pad, -pad))
-            end
-            if i == n then
-               --seq:add(SpatialConvolution(backend, ch, n_middle, 1, 1, 1, 1, 0, 0))
-            else
-               seq:add(w2nn.GradWeight(0.025))
-               seq:add(SpatialConvolution(backend, n_middle, n_middle, 1, 1, 1, 1, 0, 0))
-            end
-            seq:add(nn.LeakyReLU(0.1, true))
-            con:add(seq)
-         end
-      end
-      return nn.Sequential():add(con):add(nn.JoinTable(2))
-   end
-   local model = nn.Sequential()
-   local m = 64
-   local n = 14
-
-   model:add(nn.ConcatTable():add(nn.Identity()))
-   for i = 1, n - 1 do
-      if i == 1 then
-         model:add(skip(backend, ch, m))
-      else
-         model:add(skip(backend, m, m))
-      end
-   end
-   model:add(nn.FlattenTable())
-   model:add(concat(backend, n, ch, m))
-   model:add(SpatialFullConvolution(backend, m * (n - 1) + 3, ch, 4, 4, 2, 2, 3, 3):noBias())
-   model:add(w2nn.InplaceClip01())
-   model:add(nn.View(-1):setNumInputDims(3))
-
-   model.w2nn_arch_name = "cupconv_14"
-   model.w2nn_offset = 28
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-
-   return model
-end
-
-function srcnn.upconv_refine(backend, ch)
-   local function block(backend, ch)
-      local seq = nn.Sequential()
-      local con = nn.ConcatTable()
-      local res = nn.Sequential()
-      local base = nn.Sequential()
-      local refine = nn.Sequential()
-      local aux_con = nn.ConcatTable()
-
-      res:add(w2nn.GradWeight(0.1))
-      res:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
-      res:add(nn.LeakyReLU(0.1, true))
-      res:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
-      res:add(nn.LeakyReLU(0.1, true))
-      res:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-      res:add(nn.LeakyReLU(0.1, true))
-      res:add(SpatialConvolution(backend, 128, ch, 3, 3, 1, 1, 0, 0):noBias())
-      res:add(w2nn.InplaceClip01())
-      res:add(nn.MulConstant(0.5))
-
-      con:add(res)
-      con:add(nn.Sequential():add(nn.SpatialZeroPadding(-4, -4, -4, -4)):add(nn.MulConstant(0.5)))
-
-      -- main output
-      refine:add(nn.CAddTable()) -- averaging
-      refine:add(nn.View(-1):setNumInputDims(3))
-      -- aux output
-      base:add(nn.SelectTable(2))
-      base:add(nn.MulConstant(2)) -- revert mul 0.5
-      base:add(nn.View(-1):setNumInputDims(3))
-
-      aux_con:add(refine)
-      aux_con:add(base)
-
-      seq:add(con)
-      seq:add(aux_con)
-      seq:add(w2nn.AuxiliaryLossTable(1))
-      return seq
-   end
-   local model = nn.Sequential()
-   model:add(SpatialConvolution(backend, ch, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 32, 32, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 32, 64, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 64, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 128, 128, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialConvolution(backend, 128, 256, 3, 3, 1, 1, 0, 0))
-   model:add(nn.LeakyReLU(0.1, true))
-   model:add(SpatialFullConvolution(backend, 256, ch, 4, 4, 2, 2, 3, 3):noBias())
-   model:add(w2nn.InplaceClip01())
-   model:add(block(backend, ch))
-
-   model.w2nn_arch_name = "upconv_refine"
-   model.w2nn_offset = 18
-   model.w2nn_scale_factor = 2
-   model.w2nn_resize = true
-   model.w2nn_channels = ch
-
-   return model
-end
-
--- I devised this arch because of the block size and global average pooling problem,
--- but SEBlock may possibly learn multi-scale input and no problems occur.
-local function SpatialSEBlock(backend, ave_size, n_output, r)
-   local con = nn.ConcatTable(2)
-   local attention = nn.Sequential()
-   local n_mid = math.floor(n_output / r)
-   attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
-   attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
-   attention:add(nn.ReLU(true))
-   attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
-   attention:add(nn.Sigmoid(true))
-   attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
-   con:add(nn.Identity())
-   con:add(attention)
-   return con
-end
 
 -- Squeeze and Excitation Block
 local function SEBlock(backend, n_output, r)
@@ -760,45 +430,64 @@ local function SEBlock(backend, n_output, r)
    con:add(attention)
    return con
 end
+-- I devised this architecture for the block-size and global-average-pooling problem,
+-- but SEBlock seems to handle multi-scale input (or may simply act as a normalization) without problems,
+-- so this architecture is not used.
+local function SpatialSEBlock(backend, ave_size, n_output, r)
+   local con = nn.ConcatTable(2)
+   local attention = nn.Sequential()
+   local n_mid = math.floor(n_output / r)
+   attention:add(SpatialAveragePooling(backend, ave_size, ave_size, ave_size, ave_size))
+   attention:add(SpatialConvolution(backend, n_output, n_mid, 1, 1, 1, 1, 0, 0))
+   attention:add(nn.ReLU(true))
+   attention:add(SpatialConvolution(backend, n_mid, n_output, 1, 1, 1, 1, 0, 0))
+   attention:add(nn.Sigmoid(true))
+   attention:add(nn.SpatialUpSamplingNearest(ave_size, ave_size))
+   con:add(nn.Identity())
+   con:add(attention)
+   return con
+end
+local function unet_branch(backend, insert, n_input, n_output, depad)
+   local block = nn.Sequential()
+   local con = nn.ConcatTable(2)
+   local model = nn.Sequential()
+
+   block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0)) -- downsampling
+   block:add(nn.LeakyReLU(0.1, true))
+   block:add(insert)
+   block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0)) -- upsampling
+   block:add(nn.LeakyReLU(0.1, true))
+   con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
+   con:add(block)
+   model:add(con)
+   model:add(nn.CAddTable())
+   return model
+end
+local function unet_conv(backend, n_input, n_middle, n_output, se)
+   local model = nn.Sequential()
+   model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
+   model:add(nn.LeakyReLU(0.1, true))
+   if se then
+      model:add(SEBlock(backend, n_output, 4))
+      model:add(w2nn.ScaleTable())
+   end
+   return model
+end
--- cascaded residual channel attention unet
+-- Cascaded Residual Channel Attention U-Net
 function srcnn.upcunet(backend, ch)
-   function unet_branch(insert, backend, n_input, n_output, depad)
-      local block = nn.Sequential()
-      local con = nn.ConcatTable(2)
-      local model = nn.Sequential()
-
-      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling
-      block:add(insert)
-      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
-      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
-      con:add(block)
-      model:add(con)
-      model:add(nn.CAddTable())
-      return model
-   end
-   function unet_conv(n_input, n_middle, n_output, se)
-      local model = nn.Sequential()
-      model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      if se then
-         model:add(SEBlock(backend, n_output, 4))
-         model:add(w2nn.ScaleTable())
-      end
-      return model
-   end
    -- Residual U-Net
-   function unet(backend, ch, deconv)
-      local block1 = unet_conv(128, 256, 128, true)
+   local function unet(backend, ch, deconv)
+      local block1 = unet_conv(backend, 128, 256, 128, true)
       local block2 = nn.Sequential()
-      block2:add(unet_conv(64, 64, 128, true))
-      block2:add(unet_branch(block1, backend, 128, 128, 4))
-      block2:add(unet_conv(128, 64, 64, true))
+      block2:add(unet_conv(backend, 64, 64, 128, true))
+      block2:add(unet_branch(backend, block1, 128, 128, 4))
+      block2:add(unet_conv(backend, 128, 64, 64, true))
       local model = nn.Sequential()
-      model:add(unet_conv(ch, 32, 64, false))
-      model:add(unet_branch(block2, backend, 64, 64, 16))
+      model:add(unet_conv(backend, ch, 32, 64, false))
+      model:add(unet_branch(backend, block2, 64, 64, 16))
       model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
       model:add(nn.LeakyReLU(0.1))
       if deconv then
@@ -837,124 +526,22 @@ function srcnn.upcunet(backend, ch)
    return model
 end
 
--- cascaded residual spatial channel attention unet
-function srcnn.upcunet_v2(backend, ch)
-   function unet_branch(insert, backend, n_input, n_output, depad)
-      local block = nn.Sequential()
-      local con = nn.ConcatTable(2)
-      local model = nn.Sequential()
-
-      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling
-      block:add(insert)
-      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
-      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
-      con:add(block)
-      model:add(con)
-      model:add(nn.CAddTable())
-      return model
-   end
-   function unet_conv(n_input, n_middle, n_output, se)
-      local model = nn.Sequential()
-      model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      if se then
-         model:add(SpatialSEBlock(backend, 4, n_output, 4))
-         model:add(nn.CMulTable())
-      end
-      return model
-   end
-   -- Residual U-Net
-   function unet(backend, in_ch, out_ch, deconv)
-      local block1 = unet_conv(128, 256, 128, true)
+-- cunet for 1x
+function srcnn.cunet(backend, ch)
+   local function unet(backend, ch)
+      local block1 = unet_conv(backend, 128, 256, 128, true)
       local block2 = nn.Sequential()
-      block2:add(unet_conv(64, 64, 128, true))
-      block2:add(unet_branch(block1, backend, 128, 128, 4))
-      block2:add(unet_conv(128, 64, 64, true))
+      block2:add(unet_conv(backend, 64, 64, 128, true))
+      block2:add(unet_branch(backend, block1, 128, 128, 4))
+      block2:add(unet_conv(backend, 128, 64, 64, true))
+
       local model = nn.Sequential()
-      model:add(unet_conv(in_ch, 32, 64, false))
-      model:add(unet_branch(block2, backend, 64, 64, 16))
-      if deconv then
-         model:add(SpatialFullConvolution(backend, 64, out_ch, 4, 4, 2, 2, 3, 3):noBias())
-      else
-         model:add(SpatialConvolution(backend, 64, out_ch, 3, 3, 1, 1, 0, 0):noBias())
-      end
-      return model
-   end
-   local model = nn.Sequential()
-   local con = nn.ConcatTable()
-   local aux_con = nn.ConcatTable()
-
-   -- 2 cascade
-   model:add(unet(backend, ch, ch, true))
-   con:add(nn.Sequential():add(unet(backend, ch, ch, false)):add(nn.SpatialZeroPadding(-1, -1, -1, -1))) -- -1 for odd output size
-   con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
-
-   aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
-   aux_con:add(nn.Sequential():add(nn.SelectTable(2)):add(w2nn.InplaceClip01())) -- single unet output
-
-   model:add(con)
-   model:add(aux_con)
-   model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
-
-   model.w2nn_arch_name = "upcunet_v2"
-   model.w2nn_offset = 58
-   model.w2nn_scale_factor = 2
-   model.w2nn_channels = ch
-   model.w2nn_resize = true
-   -- {76,92,108,140} are also valid size but it is too small
-   model.w2nn_valid_input_size = {156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
-
-   return model
-end
--- cascaded residual channel attention unet
-function srcnn.upcunet_v3(backend, ch)
-   local function unet_branch(insert, backend, n_input, n_output, depad)
-      local block = nn.Sequential()
-      local con = nn.ConcatTable(2)
-      local model = nn.Sequential()
-
-      block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling
-      block:add(nn.LeakyReLU(0.1, true))
-      block:add(insert)
-      block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling
-      block:add(nn.LeakyReLU(0.1, true))
-      con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad))
-      con:add(block)
-      model:add(con)
-      model:add(nn.CAddTable())
-      return model
-   end
-   local function unet_conv(n_input, n_middle, n_output, se)
-      local model = nn.Sequential()
-      model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
-      model:add(nn.LeakyReLU(0.1, true))
-      if se then
-         model:add(SEBlock(backend, n_output, 4))
-         model:add(w2nn.ScaleTable())
-      end
-      return model
-   end
-   -- Residual U-Net
-   local function unet(backend, ch, deconv)
-      local block1 = unet_conv(128, 256, 128, true)
-      local block2 = nn.Sequential()
-      block2:add(unet_conv(64, 64, 128, true))
-      block2:add(unet_branch(block1, backend, 128, 128, 4))
-      block2:add(unet_conv(128, 64, 64, true))
-      local model = nn.Sequential()
-      model:add(unet_conv(ch, 32, 64, false))
-      model:add(unet_branch(block2, backend, 64, 64, 16))
+      model:add(unet_conv(backend, ch, 32, 64, false))
+      model:add(unet_branch(backend, block2, 64, 64, 16))
       model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
       model:add(nn.LeakyReLU(0.1))
-      if deconv then
-         model:add(SpatialFullConvolution(backend, 64, ch, 4, 4, 2, 2, 3, 3))
-      else
-         model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
-      end
+      model:add(SpatialConvolution(backend, 64, ch, 3, 3, 1, 1, 0, 0))
+
       return model
    end
    local model = nn.Sequential()
@@ -962,8 +549,8 @@
    local aux_con = nn.ConcatTable()
 
    -- 2 cascade
-   model:add(unet(backend, ch, true))
-   con:add(unet(backend, ch, false))
+   model:add(unet(backend, ch))
+   con:add(unet(backend, ch))
    con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))
 
    aux_con:add(nn.Sequential():add(nn.CAddTable()):add(w2nn.InplaceClip01())) -- cascaded unet output
@@ -973,13 +560,13 @@
    model:add(aux_con)
    model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
 
-   model.w2nn_arch_name = "upcunet_v3"
-   model.w2nn_offset = 60
-   model.w2nn_scale_factor = 2
+   model.w2nn_arch_name = "cunet"
+   model.w2nn_offset = 40
+   model.w2nn_scale_factor = 1
    model.w2nn_channels = ch
-   model.w2nn_resize = true
+   model.w2nn_resize = false
    model.w2nn_valid_input_size = {}
-   for i = 76, 512, 4 do
+   for i = 100, 512, 4 do
       table.insert(model.w2nn_valid_input_size, i)
    end
 
@@ -990,7 +577,7 @@ local function bench()
    local sys = require 'sys'
    cudnn.benchmark = true
    local model = nil
-   local arch = {"upconv_7", "upcunet", "upcunet_v3"}
+   local arch = {"upconv_7", "upcunet", "vgg_7", "cunet"}
    local backend = "cudnn"
    for k = 1, #arch do
      model = srcnn[arch[k]](backend, 3):cuda()
@@ -1040,17 +627,8 @@ local model = srcnn.cunet_v3("cunn", 3):cuda()
 print(model)
 model:training()
 print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()):size())
-local model = srcnn.upcunet_v2("cunn", 3):cuda()
-print(model)
-model:training()
-print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
-os.exit()
-local model = srcnn.upcunet_v3("cunn", 3):cuda()
-print(model)
-model:training()
-print(model:forward(torch.Tensor(1, 3, 76, 76):zero():cuda()))
-os.exit()
 bench()
+os.exit()
 --]]
 return srcnn
diff --git a/tools/find_unet.py b/tools/find_unet.py
index 515c1bc..6d2704e 100644
--- a/tools/find_unet.py
+++ b/tools/find_unet.py
@@ -1,4 +1,4 @@
-def find_unet_v2():
+def find_upcunet_v2():
     avg_pool=4
     print_mod = False
     check_mod = True
@@ -57,12 +57,12 @@ def find_unet_v2():
 #            continue
         print("ok", i, s)
 
-def find_unet():
+def find_upcunet():
     check_mod = True
     print_size = False
     print("cascade")
 
-    for i in range(76, 512):
+    for i in range(72, 512):
         print_buf = []
         s = i
         # unet 1
@@ -110,7 +110,61 @@ def find_unet():
         s = s - 2 # conv3x3 last
         #if s % avg_pool != 0:
         #    continue
-        print("ok", i, s)
+        print("ok", i, s, s / i)
 
-find_unet()
+def find_cunet():
+    check_mod = True
+    print_size = False
+    print("cascade")
+
+    for i in range(72, 512):
+        print_buf = []
+        s = i
+        # unet 1
+        s = s - 4 # conv3x3x2
+        if print_size: print("1/2", s)
+        if check_mod and s % 2 != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        if print_size: print("1/2", s)
+        if check_mod and s % 2 != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+
+        s = s * 2 # up2x2
+        if print_size: print("2x", s)
+        s = s - 4 # conv3x3x2
+        s = s * 2 # up2x2
+        if print_size: print("2x", s)
+
+        s = s - 4
+        #s = s * 2 - 4
+
+        # unet 2
+        s = s - 4 # conv3x3x2
+        if print_size: print("1/2", s)
+        if check_mod and s % 2 != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        if print_size: print("1/2", s)
+        if check_mod and s % 2 != 0:
+            continue
+        s = s / 2 # down2x2
+        s = s - 4 # conv3x3x2
+        s = s * 2 # up2x2
+        if print_size: print("2x", s)
+        s = s - 4 # conv3x3x2
+        s = s * 2 # up2x2
+        if print_size: print("2x", s)
+        s = s - 2 # conv3x3
+        s = s - 2 # conv3x3 last
+        #if s % avg_pool != 0:
+        #    continue
+        print("ok", i, s, s / i)
+
+#find_upcunet()
+find_cunet()
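
Notes (supplementary sketches, not part of the patch):

The SEBlock/w2nn.ScaleTable pair inside the hoisted unet_conv is squeeze-and-excitation channel attention: global average pooling squeezes each feature map to a single value, a two-layer 1x1-conv bottleneck with reduction r produces a sigmoid gate per channel, and the gate rescales the original feature map. For readers who know PyTorch rather than Torch7, a rough equivalent looks like the following (framework and class are my translation, not the repository's code):

    import torch.nn as nn

    class SEBlock(nn.Module):
        """Sketch of srcnn.lua's SEBlock + w2nn.ScaleTable, re-expressed in PyTorch."""
        def __init__(self, n_output, r=4):
            super().__init__()
            self.attention = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),                # global average pooling
                nn.Conv2d(n_output, n_output // r, 1),  # squeeze (bottleneck)
                nn.ReLU(inplace=True),
                nn.Conv2d(n_output // r, n_output, 1),  # excite
                nn.Sigmoid())                           # per-channel gate in [0, 1]

        def forward(self, x):
            return x * self.attention(x)                # the ScaleTable step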
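unet_branch is the residual U-Net stage: a 2x2 stride-2 convolution downsamples, an inner module runs at the lower resolution, a 2x2 stride-2 transposed convolution upsamples, and the result is summed with an identity path. Because the 3x3 convolutions are unpadded, the branch output is smaller than its input, so the identity path is cropped by depad pixels per side (SpatialZeroPadding with negative padding) before CAddTable. A hedged PyTorch sketch of the same wiring; the class name is mine, and the residual add assumes n_input == n_output, which holds at both call sites (128/128 and 64/64):

    class UNetBranch(nn.Module):
        """Sketch of srcnn.lua's unet_branch: down 2x2 -> inner block -> up 2x2,
        plus an identity skip cropped by `depad` so both paths line up."""
        def __init__(self, insert, n_input, n_output, depad):
            super().__init__()
            self.block = nn.Sequential(
                nn.Conv2d(n_input, n_input, 2, stride=2),           # downsampling
                nn.LeakyReLU(0.1, inplace=True),
                insert,                                             # inner U-Net level
                nn.ConvTranspose2d(n_output, n_output, 2, stride=2),# upsampling
                nn.LeakyReLU(0.1, inplace=True))
            self.depad = depad

        def forward(self, x):
            d = self.depad
            return x[:, :, d:-d, d:-d] + self.block(x)              # cropped skip + branch

For example, UNetBranch(inner_block, 64, 64, depad=16) would mirror the call unet_branch(backend, block2, 64, 64, 16) in srcnn.cunet.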
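The new w2nn_offset = 40 and the `for i = 100, 512, 4` loop that fills w2nn_valid_input_size follow directly from the arithmetic that find_cunet() enumerates: each U-Net subtracts 4 pixels per pair of unpadded 3x3 convolutions and halves/doubles the size twice, and every 2x2 halving requires an even size. Condensed into a single checker (the helper name is illustrative, not in the repo):

    def cunet_output_size(i):
        """Return cunet's output size for input i, or None if i is invalid."""
        s = i
        for _ in range(2):      # two cascaded U-Nets
            s -= 4              # 3x3 conv x2 (unpadded)
            if s % 2 != 0:
                return None     # 2x2 stride-2 downsampling needs an even size
            s //= 2             # down 2x2
            s -= 4              # 3x3 conv x2
            if s % 2 != 0:
                return None
            s //= 2             # down 2x2
            s -= 4              # 3x3 conv x2 at the bottom
            s = 2 * s - 4       # up 2x2, then 3x3 conv x2
            s = 2 * s - 4       # up 2x2, then the trailing convs (4 px total)
        return s

    # Every i divisible by 4 passes both evenness checks and yields i - 80,
    # i.e. 40 pixels are consumed per side -- matching w2nn_offset = 40.
    # The valid-size loop starts at 100 rather than at the smallest workable
    # size, presumably because smaller patches are not useful (compare the
    # removed upcunet_v2 comment: "{76,92,108,140} are also valid size but
    # it is too small").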