Fix tta-mode
This commit is contained in:
parent
61aeb46303
commit
f16950438c
|
@ -172,6 +172,9 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size)
|
|||
return output
|
||||
end
|
||||
function reconstruct.image(model, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
local i2rgb = false
|
||||
if x:size(1) == 1 then
|
||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||
|
@ -194,6 +197,9 @@ function reconstruct.image(model, x, block_size)
|
|||
return x
|
||||
end
|
||||
function reconstruct.scale(model, scale, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
local i2rgb = false
|
||||
if x:size(1) == 1 then
|
||||
local new_x = torch.Tensor(3, x:size(2), x:size(3))
|
||||
|
@ -287,6 +293,9 @@ local function tta(f, n, model, x, block_size)
|
|||
return average:div(#augments)
|
||||
end
|
||||
function reconstruct.image_tta(model, n, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
if reconstruct.is_rgb(model) then
|
||||
return tta(reconstruct.image_rgb, n, model, x, block_size)
|
||||
else
|
||||
|
@ -294,6 +303,9 @@ function reconstruct.image_tta(model, n, x, block_size)
|
|||
end
|
||||
end
|
||||
function reconstruct.scale_tta(model, n, scale, x, block_size)
|
||||
if model.w2nn_input_size then
|
||||
block_size = model.w2nn_input_size
|
||||
end
|
||||
if reconstruct.is_rgb(model) then
|
||||
local f = function (model, x, offset, block_size)
|
||||
return reconstruct.scale_rgb(model, scale, x, offset, block_size)
|
||||
|
|
Loading…
Reference in a new issue