From 06246e0d78c4201ac005cf5df58761162082e39d Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 21 Oct 2018 00:56:39 +0000 Subject: [PATCH] refactor --- lib/srcnn.lua | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/lib/srcnn.lua b/lib/srcnn.lua index 191191c..5fa115a 100644 --- a/lib/srcnn.lua +++ b/lib/srcnn.lua @@ -987,16 +987,15 @@ end function srcnn.cunet_v6(backend, ch) function unet_branch(insert, backend, n_input, n_output, depad) local block = nn.Sequential() - local pooling = SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0) -- downsampling - --block:add(w2nn.Print()) - block:add(pooling) + local con = nn.ConcatTable(2) + local model = nn.Sequential() + + block:add(SpatialConvolution(backend, n_input, n_input, 2, 2, 2, 2, 0, 0))-- downsampling block:add(insert) block:add(SpatialFullConvolution(backend, n_output, n_output, 2, 2, 2, 2, 0, 0))-- upsampling - local parallel = nn.ConcatTable(2) - parallel:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad)) - parallel:add(block) - local model = nn.Sequential() - model:add(parallel) + con:add(nn.SpatialZeroPadding(-depad, -depad, -depad, -depad)) + con:add(block) + model:add(con) model:add(nn.CAddTable()) return model end @@ -1015,7 +1014,7 @@ function srcnn.cunet_v6(backend, ch) attention:add(nn.ReLU(true)) attention:add(SpatialConvolution(backend, math.floor(n_output / 4), n_output, 1, 1, 1, 1, 0, 0)) attention:add(nn.Sigmoid(true)) - con:add(nn.Identity()) + con:add(nn.Identity()) con:add(attention) model:add(con) model:add(w2nn.ScaleTable()) @@ -1046,7 +1045,6 @@ function srcnn.cunet_v6(backend, ch) local aux_con = nn.ConcatTable() model:add(unet(backend, ch, true)) - con:add(unet(backend, ch, false)) con:add(nn.SpatialZeroPadding(-20, -20, -20, -20))