1
0
Fork 0
mirror of synced 2024-05-16 19:02:21 +12:00

performance tuning

This commit is contained in:
nagadomi 2018-11-04 00:21:56 +09:00
parent 6655d03971
commit daadbaccae

View file

@@ -30,38 +30,38 @@ local function reconstruct_nn(model, x, inner_scale, offset, block_size, batch_s
end
end
end
local input = torch.Tensor(batch_size, ch, input_block_size, input_block_size)
local input_cuda = torch.CudaTensor(batch_size, ch, input_block_size, input_block_size)
for i = 1, #input_indexes, batch_size do
local c = 0
local output
for j = 0, batch_size - 1 do
if i + j > #input_indexes then
break
local input = torch.Tensor(#input_indexes, ch, input_block_size, input_block_size)
local input_cuda = torch.CudaTensor():resize(input:size())
local output_cuda = torch.CudaTensor():resize(new_x:size())
for i = 1, #input_indexes do
input[i]:copy(x[input_indexes[i]])
if model.w2nn_gcn then
local mean = input[i]:mean()
local stdv = input[i]:std()
if stdv > 0 then
input[i]:add(-mean):div(stdv)
else
input[i]:add(-mean)
end
input[j+1]:copy(x[input_indexes[i + j]])
if model.w2nn_gcn then
local mean = input[j + 1]:mean()
local stdv = input[j + 1]:std()
if stdv > 0 then
input[j + 1]:add(-mean):div(stdv)
else
input[j + 1]:add(-mean)
end
end
c = c + 1
end
input_cuda:copy(input)
if c == batch_size then
output = model:forward(input_cuda)
else
output = model:forward(input_cuda:narrow(1, 1, c))
end
--output = output:view(batch_size, ch, output_size, output_size)
for j = 0, c - 1 do
new_x[output_indexes[i + j]]:copy(output[j+1])
end
end
input_cuda:copy(input)
local batch_n = math.floor(#input_indexes / batch_size)
local batch_rem = #input_indexes % batch_size
for i = 1, batch_n * batch_size, batch_size do
local output = model:forward(input_cuda:narrow(1, i, batch_size))
for j = 0, batch_size - 1 do
output_cuda[output_indexes[i + j]]:copy(output[j + 1])
end
end
if batch_rem > 0 then
local i = 1 + batch_n * batch_size
local output = model:forward(input_cuda:narrow(1, i, batch_rem))
for j = 0, batch_rem - 1 do
output_cuda[output_indexes[i + j]]:copy(output[j+1])
end
end
new_x:copy(output_cuda)
return new_x
end
local reconstruct = {}