1
0
Fork 0
mirror of synced 2024-05-16 19:02:21 +12:00
This commit is contained in:
nagadomi 2018-10-18 19:49:10 +00:00
parent eea286059f
commit ef5aa1ccbb

View file

@ -984,7 +984,7 @@ function srcnn.cunet_v4(backend, ch)
return model
end
function srcnn.cunet_v5(backend, ch)
function srcnn.cunet_v6(backend, ch)
function unet_branch(insert, backend, n_input, n_output, depad)
local block = nn.Sequential()
local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling
@ -1000,22 +1000,37 @@ function srcnn.cunet_v5(backend, ch)
model:add(nn.CAddTable())
return model
end
function unet_conv(n_input, n_middle, n_output)
function unet_conv(n_input, n_middle, n_output, se)
local model = nn.Sequential()
model:add(SpatialConvolution(backend, n_input, n_middle, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
model:add(SpatialConvolution(backend, n_middle, n_output, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1, true))
if se then
-- Squeeze and Excitation Networks
local con = nn.ConcatTable(2)
local attention = nn.Sequential()
attention:add(nn.SpatialAdaptiveAveragePooling(1, 1)) -- global average pooling
attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / 4), 1, 1, 1, 1, 0, 0))
attention:add(nn.ReLU(true))
attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0))
attention:add(nn.Sigmoid(true))
con:add(nn.Identity())
con:add(attention)
model:add(con)
model:add(w2nn.ScaleTable())
end
return model
end
-- Residual U-Net
function unet(backend, ch, deconv)
local block1 = unet_conv(128, 256, 128)
local block1 = unet_conv(128, 256, 128, true)
local block2 = nn.Sequential()
block2:add(unet_conv(64, 64, 128))
block2:add(unet_conv(64, 64, 128, true))
block2:add(unet_branch(block1, backend, 128, 128, 4))
block2:add(unet_conv(128, 64, 64))
block2:add(unet_conv(128, 64, 64, true))
local model = nn.Sequential()
model:add(unet_conv(ch, 32, 64))
model:add(unet_conv(ch, 32, 64, false))
model:add(unet_branch(block2, backend, 64, 64, 16))
model:add(SpatialConvolution(backend, 64, 64, 3, 3, 1, 1, 0, 0))
model:add(nn.LeakyReLU(0.1))
@ -1042,7 +1057,7 @@ function srcnn.cunet_v5(backend, ch)
model:add(aux_con)
model:add(w2nn.AuxiliaryLossTable(1)) -- auxiliary loss for single unet output
model.w2nn_arch_name = "cunet_v5"
model.w2nn_arch_name = "cunet_v6"
model.w2nn_offset = 60
model.w2nn_scale_factor = 2
model.w2nn_channels = ch
@ -1186,13 +1201,13 @@ print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()):size())
os.exit()
local model = srcnn.cunet_v5("cunn", 3):cuda()
local model = srcnn.cunet_v6("cunn", 3):cuda()
print(model)
model:training()
print(model:forward(torch.Tensor(1, 3, 144, 144):zero():cuda()))
os.exit()
--]]
return srcnn