diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 75344c8..72f40f9 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -204,7 +204,7 @@ function reconstruct.scale_rgb(model, scale, x, offset, block_size, batch_size) collectgarbage() return output end -function reconstruct.image(model, x, block_size) +function reconstruct.image(model, x, block_size, batch_size) local i2rgb = false if x:size(1) == 1 then local new_x = torch.Tensor(3, x:size(2), x:size(3)) @@ -216,17 +216,17 @@ function reconstruct.image(model, x, block_size) end if reconstruct.is_rgb(model) then x = reconstruct.image_rgb(model, x, - reconstruct.offset_size(model), block_size) + reconstruct.offset_size(model), block_size, batch_size) else x = reconstruct.image_y(model, x, - reconstruct.offset_size(model), block_size) + reconstruct.offset_size(model), block_size, batch_size) end if i2rgb then x = image.rgb2y(x) end return x end -function reconstruct.scale(model, scale, x, block_size) +function reconstruct.scale(model, scale, x, block_size, batch_size) local i2rgb = false if x:size(1) == 1 then local new_x = torch.Tensor(3, x:size(2), x:size(3)) @@ -239,11 +239,11 @@ function reconstruct.scale(model, scale, x, block_size) if reconstruct.is_rgb(model) then x = reconstruct.scale_rgb(model, scale, x, reconstruct.offset_size(model), - block_size) + block_size, batch_size) else x = reconstruct.scale_y(model, scale, x, reconstruct.offset_size(model), - block_size) + block_size, batch_size) end if i2rgb then x = image.rgb2y(x)