From daadbaccaea4943c5e5a9be2f1771f716b41643a Mon Sep 17 00:00:00 2001 From: nagadomi Date: Sun, 4 Nov 2018 00:21:56 +0900 Subject: [PATCH] performance tuning --- lib/reconstruct.lua | 58 ++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/lib/reconstruct.lua b/lib/reconstruct.lua index 72f40f9..23c4cdc 100644 --- a/lib/reconstruct.lua +++ b/lib/reconstruct.lua @@ -30,38 +30,38 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s end end end - local input = torch.Tensor(batch_size, ch, input_block_size, input_block_size) - local input_cuda = torch.CudaTensor(batch_size, ch, input_block_size, input_block_size) - for i = 1, #input_indexes, batch_size do - local c = 0 - local output - for j = 0, batch_size - 1 do - if i + j > #input_indexes then - break + local input = torch.Tensor(#input_indexes, ch, input_block_size, input_block_size) + local input_cuda = torch.CudaTensor():resize(input:size()) + local output_cuda = torch.CudaTensor():resize(new_x:size()) + for i = 1, #input_indexes do + input[i]:copy(x[input_indexes[i]]) + if model.w2nn_gcn then + local mean = input[i]:mean() + local stdv = input[i]:std() + if stdv > 0 then + input[i]:add(-mean):div(stdv) + else + input[i]:add(-mean) end - input[j+1]:copy(x[input_indexes[i + j]]) - if model.w2nn_gcn then - local mean = input[j + 1]:mean() - local stdv = input[j + 1]:std() - if stdv > 0 then - input[j + 1]:add(-mean):div(stdv) - else - input[j + 1]:add(-mean) - end - end - c = c + 1 - end - input_cuda:copy(input) - if c == batch_size then - output = model:forward(input_cuda) - else - output = model:forward(input_cuda:narrow(1, 1, c)) - end - --output = output:view(batch_size, ch, output_size, output_size) - for j = 0, c - 1 do - new_x[output_indexes[i + j]]:copy(output[j+1]) end end + input_cuda:copy(input) + local batch_n = math.floor(#input_indexes / batch_size) + local batch_rem = #input_indexes % batch_size + for i = 1, batch_n * batch_size, batch_size do + local output = model:forward(input_cuda:narrow(1, i, batch_size)) + for j = 0, batch_size - 1 do + output_cuda[output_indexes[i + j]]:copy(output[j + 1]) + end + end + if batch_rem > 0 then + local i = 1 + batch_n * batch_size + local output = model:forward(input_cuda:narrow(1, i, batch_rem)) + for j = 0, batch_rem - 1 do + output_cuda[output_indexes[i + j]]:copy(output[j+1]) + end + end + new_x:copy(output_cuda) return new_x end local reconstruct = {}