Add support for w2nn_valid_input_size
This commit is contained in:
parent
8051535fdc
commit
2d85f70415
|
@ -109,8 +109,27 @@ local function padding_params(x, model, block_size)
|
|||
p.pad_w2 = (w - input_offset) - p.x_w
|
||||
return p
|
||||
end
|
||||
local function find_valid_block_size(model, block_size)
|
||||
if model.w2nn_input_size ~= nil then
|
||||
return model.w2nn_input_size
|
||||
elseif model.w2nn_valid_input_size ~= nil then
|
||||
local best_size = 0
|
||||
local best_diff = 10000
|
||||
for i = 1, #model.w2nn_valid_input_size do
|
||||
local diff = math.abs(model.w2nn_valid_input_size[i] - block_size)
|
||||
if diff < best_diff then
|
||||
best_size = model.w2nn_valid_input_size[i]
|
||||
best_diff = diff
|
||||
end
|
||||
end
|
||||
assert(best_size > 0)
|
||||
return best_size
|
||||
else
|
||||
return block_size
|
||||
end
|
||||
end
|
||||
function reconstruct.image_y(model, x, offset, block_size, batch_size)
|
||||
block_size = block_size or 128
|
||||
block_size = find_valid_block_size(model, block_size or 128)
|
||||
local p = padding_params(x, model, block_size)
|
||||
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
||||
x = x:cuda()
|
||||
|
@ -126,7 +145,7 @@ function reconstruct.image_y(model, x, offset, block_size, batch_size)
|
|||
return output
|
||||
end
|
||||
function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
|
||||
block_size = block_size or 128
|
||||
block_size = find_valid_block_size(model, block_size or 128)
|
||||
local x_lanczos
|
||||
if reconstruct.has_resize(model) then
|
||||
x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
|
||||
|
@ -153,7 +172,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
|
|||
return output
|
||||
end
|
||||
function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
|
||||
block_size = block_size or 128
|
||||
block_size = find_valid_block_size(model, block_size or 128)
|
||||
local p = padding_params(x, model, block_size)
|
||||
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
||||
if p.x_w * p.x_h > 2048*2048 then
|
||||
|
@ -168,7 +187,7 @@ function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
|
|||
return output
|
||||
end
|
||||
function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
||||
block_size = block_size or 128
|
||||
block_size = find_valid_block_size(model, block_size or 128)
|
||||
if not reconstruct.has_resize(model) then
|
||||
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
|
||||
end
|
||||
|
@ -186,9 +205,6 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
|||
return output
|
||||
end
|
||||
function reconstruct.image(model, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
local i2rgb = false
|
||||
if x:size(1) == 1 then
|
||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||
|
@ -211,9 +227,6 @@ function reconstruct.image(model, x, block_size)
|
|||
return x
|
||||
end
|
||||
function reconstruct.scale(model, scale, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
local i2rgb = false
|
||||
if x:size(1) == 1 then
|
||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||
|
|
|
@ -169,6 +169,17 @@ local function ReLU(backend)
|
|||
end
|
||||
srcnn.ReLU = ReLU
|
||||
|
||||
local function Sigmoid(backend)
|
||||
if backend == "cunn" then
|
||||
return nn.Sigmoid(true)
|
||||
elseif backend == "cudnn" then
|
||||
return cudnn.Sigmoid(true)
|
||||
else
|
||||
error("unsupported backend:" .. backend)
|
||||
end
|
||||
end
|
||||
srcnn.ReLU = ReLU
|
||||
|
||||
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
||||
if backend == "cunn" then
|
||||
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
|
||||
|
@ -821,11 +832,11 @@ function srcnn.upcunet_v2(backend, ch)
|
|||
local se_fac = 4
|
||||
local con = nn.ConcatTable(2)
|
||||
local attention = nn.Sequential()
|
||||
attention:add(nn.SpatialAveragePooling(4, 4, 4, 4))
|
||||
attention:add(SpatialAveragePooling(backend, 4, 4, 4, 4))
|
||||
attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0))
|
||||
attention:add(nn.ReLU(true))
|
||||
attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0))
|
||||
attention:add(nn.Sigmoid(true))
|
||||
attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
|
||||
attention:add(nn.SpatialUpSamplingNearest(4, 4))
|
||||
con:add(nn.Identity())
|
||||
con:add(attention)
|
||||
|
@ -872,7 +883,8 @@ function srcnn.upcunet_v2(backend, ch)
|
|||
model.w2nn_scale_factor = 2
|
||||
model.w2nn_channels = ch
|
||||
model.w2nn_resize = true
|
||||
model.w2nn_valid_input_size = {76,92,108,140,156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
|
||||
-- {76,92,108,140} are also valid size but it is too small
|
||||
model.w2nn_valid_input_size = {156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
|
||||
|
||||
return model
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue