From 2d85f7041549ed0b495330d32f131bb5db765c62 Mon Sep 17 00:00:00 2001
From: nagadomi
Date: Sat, 27 Oct 2018 07:59:56 +0000
Subject: [PATCH] Add support for w2nn_valid_input_size

Choose the reconstruction block size in one place, find_valid_block_size():
use model.w2nn_input_size when it is set; otherwise pick the entry of
model.w2nn_valid_input_size closest to the requested size; otherwise keep
the requested size. image_y/scale_y/image_rgb/scale_rgb call it in place of
the bare `block_size or 128` default, and the w2nn_input_size overrides are
removed from reconstruct.image and reconstruct.scale.

In srcnn, add a backend-dispatching Sigmoid helper (exported as
srcnn.Sigmoid), build the upcunet_v2 attention pooling through
SpatialAveragePooling(backend, ...), and drop the smallest entries from
upcunet_v2's w2nn_valid_input_size list.
---
Notes (ignored by git am):
 * Indentation of context lines was reconstructed; if it does not match the
   target tree, apply with `git apply --ignore-whitespace`.
 * The `index` lines are kept from the original mail; the post-image blob
   ids no longer correspond to the edited '+' lines.

 lib/reconstruct.lua | 33 +++++++++++++++++++++++----------
 lib/srcnn.lua       | 18 +++++++++++++++---
 2 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua
index 6c37f15..75344c8 100644
--- a/lib/reconstruct.lua
+++ b/lib/reconstruct.lua
@@ -109,8 +109,27 @@ local function padding_params(x, model, block_size)
    p.pad_w2 = (w - input_offset) - p.x_w
    return p
 end
+local function find_valid_block_size(model, block_size) -- resolve the block size the model accepts
+   if model.w2nn_input_size ~= nil then
+      return model.w2nn_input_size -- a fixed input size overrides the request
+   elseif model.w2nn_valid_input_size ~= nil then
+      local best_size = 0
+      local best_diff = math.huge -- any finite distance beats the initial value
+      for i = 1, #model.w2nn_valid_input_size do
+         local diff = math.abs(model.w2nn_valid_input_size[i] - block_size)
+         if diff < best_diff then -- strict '<': the first of equally close sizes wins
+            best_size = model.w2nn_valid_input_size[i]
+            best_diff = diff
+         end
+      end
+      assert(best_size > 0, "no usable size in w2nn_valid_input_size")
+      return best_size
+   else
+      return block_size -- model declares no size constraint
+   end
+end
 function reconstruct.image_y(model, x, offset, block_size, batch_size)
-   block_size = block_size or 128
+   block_size = find_valid_block_size(model, block_size or 128)
    local p = padding_params(x, model, block_size)
    x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
    x = x:cuda()
@@ -126,7 +145,7 @@ function reconstruct.image_y(model, x, offset, block_size, batch_size)
    return output
 end
 function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
-   block_size = block_size or 128
+   block_size = find_valid_block_size(model, block_size or 128)
    local x_lanczos
    if reconstruct.has_resize(model) then
       x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
@@ -153,7 +172,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
    return output
 end
 function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
-   block_size = block_size or 128
+   block_size = find_valid_block_size(model, block_size or 128)
    local p = padding_params(x, model, block_size)
    x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
    if p.x_w * p.x_h > 2048*2048 then
@@ -168,7 +187,7 @@ function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
    return output
 end
 function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
-   block_size = block_size or 128
+   block_size = find_valid_block_size(model, block_size or 128)
    if not reconstruct.has_resize(model) then
       x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
    end
@@ -186,9 +205,6 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
    return output
 end
 function reconstruct.image(model, x, block_size)
-   if model.w2nn_input_size then
-      block_size = model.w2nn_input_size
-   end
    local i2rgb = false
    if x:size(1) == 1 then
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
@@ -211,9 +227,6 @@ function reconstruct.image(model, x, block_size)
    return x
 end
 function reconstruct.scale(model, scale, x, block_size)
-   if model.w2nn_input_size then
-      block_size = model.w2nn_input_size
-   end
    local i2rgb = false
    if x:size(1) == 1 then
       local new_x = torch.Tensor(3, x:size(2), x:size(3))
diff --git a/lib/srcnn.lua b/lib/srcnn.lua
index 7f4b1f6..68f0b14 100644
--- a/lib/srcnn.lua
+++ b/lib/srcnn.lua
@@ -169,6 +169,17 @@ local function ReLU(backend)
 end
 srcnn.ReLU = ReLU
 
+local function Sigmoid(backend) -- in-place Sigmoid module for the given backend
+   if backend == "cunn" then
+      return nn.Sigmoid(true)
+   elseif backend == "cudnn" then
+      return cudnn.Sigmoid(true)
+   else
+      error("unsupported backend:" .. backend)
+   end
+end
+srcnn.Sigmoid = Sigmoid
+
 local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
    if backend == "cunn" then
       return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
@@ -821,11 +832,11 @@ function srcnn.upcunet_v2(backend, ch)
       local se_fac = 4
       local con = nn.ConcatTable(2)
       local attention = nn.Sequential()
-      attention:add(nn.SpatialAveragePooling(4, 4, 4, 4))
+      attention:add(SpatialAveragePooling(backend, 4, 4, 4, 4))
       attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0))
       attention:add(nn.ReLU(true))
       attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0))
-      attention:add(nn.Sigmoid(true))
+      attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
       attention:add(nn.SpatialUpSamplingNearest(4, 4))
       con:add(nn.Identity())
       con:add(attention)
@@ -872,7 +883,8 @@ function srcnn.upcunet_v2(backend, ch)
    model.w2nn_scale_factor = 2
    model.w2nn_channels = ch
    model.w2nn_resize = true
-   model.w2nn_valid_input_size = {76,92,108,140,156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
+   -- {76,92,108,140} are also valid sizes, but they are too small
+   model.w2nn_valid_input_size = {156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
 
    return model
 end