1
0
Fork 0
mirror of synced 2024-05-17 03:12:18 +12:00

Add support for w2nn_valid_input_size

This commit is contained in:
nagadomi 2018-10-27 07:59:56 +00:00
parent 8051535fdc
commit 2d85f70415
2 changed files with 38 additions and 13 deletions

View file

@ -109,8 +109,27 @@ local function padding_params(x, model, block_size)
p.pad_w2 = (w - input_offset) - p.x_w
return p
end
local function find_valid_block_size(model, block_size)
if model.w2nn_input_size ~= nil then
return model.w2nn_input_size
elseif model.w2nn_valid_input_size ~= nil then
local best_size = 0
local best_diff = 10000
for i = 1, #model.w2nn_valid_input_size do
local diff = math.abs(model.w2nn_valid_input_size[i] - block_size)
if diff < best_diff then
best_size = model.w2nn_valid_input_size[i]
best_diff = diff
end
end
assert(best_size > 0)
return best_size
else
return block_size
end
end
function reconstruct.image_y(model, x, offset, block_size, batch_size)
block_size = block_size or 128
block_size = find_valid_block_size(model, block_size or 128)
local p = padding_params(x, model, block_size)
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
x = x:cuda()
@ -126,7 +145,7 @@ function reconstruct.image_y(model, x, offset, block_size, batch_size)
return output
end
function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
block_size = block_size or 128
block_size = find_valid_block_size(model, block_size or 128)
local x_lanczos
if reconstruct.has_resize(model) then
x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
@ -153,7 +172,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
return output
end
function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
block_size = block_size or 128
block_size = find_valid_block_size(model, block_size or 128)
local p = padding_params(x, model, block_size)
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
if p.x_w * p.x_h > 2048*2048 then
@ -168,7 +187,7 @@ function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
return output
end
function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
block_size = block_size or 128
block_size = find_valid_block_size(model, block_size or 128)
if not reconstruct.has_resize(model) then
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
end
@ -186,9 +205,6 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
return output
end
function reconstruct.image(model, x, block_size)
if model.w2nn_input_size then
block_size = model.w2nn_input_size
end
local i2rgb = false
if x:size(1) == 1 then
local new_x = torch.Tensor(3, x:size(2), x:size(3))
@ -211,9 +227,6 @@ function reconstruct.image(model, x, block_size)
return x
end
function reconstruct.scale(model, scale, x, block_size)
if model.w2nn_input_size then
block_size = model.w2nn_input_size
end
local i2rgb = false
if x:size(1) == 1 then
local new_x = torch.Tensor(3, x:size(2), x:size(3))

View file

@ -169,6 +169,17 @@ local function ReLU(backend)
end
srcnn.ReLU = ReLU
local function Sigmoid(backend)
if backend == "cunn" then
return nn.Sigmoid(true)
elseif backend == "cudnn" then
return cudnn.Sigmoid(true)
else
error("unsupported backend:" .. backend)
end
end
srcnn.ReLU = ReLU
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
if backend == "cunn" then
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
@ -821,11 +832,11 @@ function srcnn.upcunet_v2(backend, ch)
local se_fac = 4
local con = nn.ConcatTable(2)
local attention = nn.Sequential()
attention:add(nn.SpatialAveragePooling(4, 4, 4, 4))
attention:add(SpatialAveragePooling(backend, 4, 4, 4, 4))
attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0))
attention:add(nn.ReLU(true))
attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0))
attention:add(nn.Sigmoid(true))
attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
attention:add(nn.SpatialUpSamplingNearest(4, 4))
con:add(nn.Identity())
con:add(attention)
@ -872,7 +883,8 @@ function srcnn.upcunet_v2(backend, ch)
model.w2nn_scale_factor = 2
model.w2nn_channels = ch
model.w2nn_resize = true
model.w2nn_valid_input_size = {76,92,108,140,156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
-- {76,92,108,140} are also valid size but it is too small
model.w2nn_valid_input_size = {156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
return model
end