From f16950438cb5482ff91e970b47d7d0076dbe6c10 Mon Sep 17 00:00:00 2001 From: nagadomi Date: Mon, 24 Oct 2016 09:06:27 +0900 Subject: [PATCH] Fix tta-mode --- lib/reconstruct.lua | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 138bb6f..68d2736 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -172,6 +172,9 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size) return output end function reconstruct.image(model, x, block_size) + if model.w2nn_input_size then + block_size = model.w2nn_input_size + end local i2rgb = false if x:size(1) == 1 then local new_x = torch.Tensor(3, x:size(2), x:size(3)) @@ -194,6 +197,9 @@ function reconstruct.image(model, x, block_size) return x end function reconstruct.scale(model, scale, x, block_size) + if model.w2nn_input_size then + block_size = model.w2nn_input_size + end local i2rgb = false if x:size(1) == 1 then local new_x = torch.Tensor(3, x:size(2), x:size(3)) @@ -287,6 +293,9 @@ local function tta(f, n, model, x, block_size) return average:div(#augments) end function reconstruct.image_tta(model, n, x, block_size) + if model.w2nn_input_size then + block_size = model.w2nn_input_size + end if reconstruct.is_rgb(model) then return tta(reconstruct.image_rgb, n, model, x, block_size) else @@ -294,6 +303,9 @@ function reconstruct.image_tta(model, n, x, block_size) end end function reconstruct.scale_tta(model, n, scale, x, block_size) + if model.w2nn_input_size then + block_size = model.w2nn_input_size + end if reconstruct.is_rgb(model) then local f = function (model, x, offset, block_size) return reconstruct.scale_rgb(model, scale, x, offset, block_size)