Add support for w2nn_valid_input_size
This commit is contained in:
parent
8051535fdc
commit
2d85f70415
|
@ -109,8 +109,27 @@ local function padding_params(x, model, block_size)
|
||||||
p.pad_w2 = (w - input_offset) - p.x_w
|
p.pad_w2 = (w - input_offset) - p.x_w
|
||||||
return p
|
return p
|
||||||
end
|
end
|
||||||
|
-- Resolve the block size the model can actually accept.
-- Priority:
--   1. model.w2nn_input_size       -- model requires one fixed input size
--   2. model.w2nn_valid_input_size -- pick the valid size nearest the request
--   3. otherwise                   -- any size works; return the request as-is
-- @param model       loaded w2nn model table
-- @param block_size  requested block size (number)
-- @return a block size that is valid for this model
local function find_valid_block_size(model, block_size)
   if model.w2nn_input_size ~= nil then
      return model.w2nn_input_size
   elseif model.w2nn_valid_input_size ~= nil then
      local valid_sizes = model.w2nn_valid_input_size
      local best_size = 0
      -- math.huge instead of a magic constant (was 10000): with the old
      -- bound, a request more than 10000 away from every valid size would
      -- leave best_size at 0 and trip the assert below.
      local best_diff = math.huge
      for i = 1, #valid_sizes do
         local diff = math.abs(valid_sizes[i] - block_size)
         if diff < best_diff then
            best_size = valid_sizes[i]
            best_diff = diff
         end
      end
      -- w2nn_valid_input_size must contain at least one positive size
      assert(best_size > 0)
      return best_size
   else
      return block_size
   end
end
|
||||||
function reconstruct.image_y(model, x, offset, block_size, batch_size)
|
function reconstruct.image_y(model, x, offset, block_size, batch_size)
|
||||||
block_size = block_size or 128
|
block_size = find_valid_block_size(model, block_size or 128)
|
||||||
local p = padding_params(x, model, block_size)
|
local p = padding_params(x, model, block_size)
|
||||||
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
||||||
x = x:cuda()
|
x = x:cuda()
|
||||||
|
@ -126,7 +145,7 @@ function reconstruct.image_y(model, x, offset, block_size, batch_size)
|
||||||
return output
|
return output
|
||||||
end
|
end
|
||||||
function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
|
function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
|
||||||
block_size = block_size or 128
|
block_size = find_valid_block_size(model, block_size or 128)
|
||||||
local x_lanczos
|
local x_lanczos
|
||||||
if reconstruct.has_resize(model) then
|
if reconstruct.has_resize(model) then
|
||||||
x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
|
x_lanczos = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Lanczos")
|
||||||
|
@ -153,7 +172,7 @@ function reconstruct.scale_y(model, scale, x, offset, block_size, batch_size)
|
||||||
return output
|
return output
|
||||||
end
|
end
|
||||||
function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
|
function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
|
||||||
block_size = block_size or 128
|
block_size = find_valid_block_size(model, block_size or 128)
|
||||||
local p = padding_params(x, model, block_size)
|
local p = padding_params(x, model, block_size)
|
||||||
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
x = iproc.padding(x, p.pad_w1, p.pad_w2, p.pad_h1, p.pad_h2)
|
||||||
if p.x_w * p.x_h > 2048*2048 then
|
if p.x_w * p.x_h > 2048*2048 then
|
||||||
|
@ -168,7 +187,7 @@ function reconstruct.image_rgb(model, x, offset, block_size, batch_size)
|
||||||
return output
|
return output
|
||||||
end
|
end
|
||||||
function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
||||||
block_size = block_size or 128
|
block_size = find_valid_block_size(model, block_size or 128)
|
||||||
if not reconstruct.has_resize(model) then
|
if not reconstruct.has_resize(model) then
|
||||||
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
|
x = iproc.scale(x, x:size(3) * scale, x:size(2) * scale, "Box")
|
||||||
end
|
end
|
||||||
|
@ -186,9 +205,6 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
||||||
return output
|
return output
|
||||||
end
|
end
|
||||||
function reconstruct.image(model, x, block_size)
|
function reconstruct.image(model, x, block_size)
|
||||||
if model.w2nn_input_size then
|
|
||||||
block_size = model.w2nn_input_size
|
|
||||||
end
|
|
||||||
local i2rgb = false
|
local i2rgb = false
|
||||||
if x:size(1) == 1 then
|
if x:size(1) == 1 then
|
||||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||||
|
@ -211,9 +227,6 @@ function reconstruct.image(model, x, block_size)
|
||||||
return x
|
return x
|
||||||
end
|
end
|
||||||
function reconstruct.scale(model, scale, x, block_size)
|
function reconstruct.scale(model, scale, x, block_size)
|
||||||
if model.w2nn_input_size then
|
|
||||||
block_size = model.w2nn_input_size
|
|
||||||
end
|
|
||||||
local i2rgb = false
|
local i2rgb = false
|
||||||
if x:size(1) == 1 then
|
if x:size(1) == 1 then
|
||||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||||
|
|
|
@ -169,6 +169,17 @@ local function ReLU(backend)
|
||||||
end
|
end
|
||||||
srcnn.ReLU = ReLU
|
srcnn.ReLU = ReLU
|
||||||
|
|
||||||
|
-- Sigmoid activation layer for the given backend.
-- Mirrors the ReLU/SpatialMaxPooling backend-dispatch helpers in this file.
-- The `true` argument requests in-place operation to save memory.
-- @param backend "cunn" or "cudnn"
-- @return an nn/cudnn Sigmoid module
-- @raise error on any other backend string
local function Sigmoid(backend)
   if backend == "cunn" then
      return nn.Sigmoid(true)
   elseif backend == "cudnn" then
      return cudnn.Sigmoid(true)
   else
      error("unsupported backend:" .. backend)
   end
end
-- Bug fix: this export previously read `srcnn.ReLU = ReLU` (a copy-paste of
-- the line above the Sigmoid definition), so Sigmoid was never exported on
-- the module table. ReLU is already exported earlier, so this is safe.
srcnn.Sigmoid = Sigmoid
|
||||||
|
|
||||||
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
local function SpatialMaxPooling(backend, kW, kH, dW, dH, padW, padH)
|
||||||
if backend == "cunn" then
|
if backend == "cunn" then
|
||||||
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
|
return nn.SpatialMaxPooling(kW, kH, dW, dH, padW, padH)
|
||||||
|
@ -821,11 +832,11 @@ function srcnn.upcunet_v2(backend, ch)
|
||||||
local se_fac = 4
|
local se_fac = 4
|
||||||
local con = nn.ConcatTable(2)
|
local con = nn.ConcatTable(2)
|
||||||
local attention = nn.Sequential()
|
local attention = nn.Sequential()
|
||||||
attention:add(nn.SpatialAveragePooling(4, 4, 4, 4))
|
attention:add(SpatialAveragePooling(backend, 4, 4, 4, 4))
|
||||||
attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0))
|
attention:add(SpatialConvolution(backend, n_output, math.floor(n_output / se_fac), 1, 1, 1, 1, 0, 0))
|
||||||
attention:add(nn.ReLU(true))
|
attention:add(nn.ReLU(true))
|
||||||
attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0))
|
attention:add(SpatialConvolution(backend, math.floor(n_output / se_fac), n_output, 1, 1, 1, 1, 0, 0))
|
||||||
attention:add(nn.Sigmoid(true))
|
attention:add(nn.Sigmoid(true)) -- don't use cudnn sigmoid
|
||||||
attention:add(nn.SpatialUpSamplingNearest(4, 4))
|
attention:add(nn.SpatialUpSamplingNearest(4, 4))
|
||||||
con:add(nn.Identity())
|
con:add(nn.Identity())
|
||||||
con:add(attention)
|
con:add(attention)
|
||||||
|
@ -872,7 +883,8 @@ function srcnn.upcunet_v2(backend, ch)
|
||||||
model.w2nn_scale_factor = 2
|
model.w2nn_scale_factor = 2
|
||||||
model.w2nn_channels = ch
|
model.w2nn_channels = ch
|
||||||
model.w2nn_resize = true
|
model.w2nn_resize = true
|
||||||
model.w2nn_valid_input_size = {76,92,108,140,156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
|
-- {76,92,108,140} are also valid sizes, but they are too small to be useful
|
||||||
|
model.w2nn_valid_input_size = {156,172,188,204,220,236,252,268,284,300,316,332,348,364,380,396,412,428,444,460,476,492,508}
|
||||||
|
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue