(PyTorch) 10X Speed up! Stack as an Alternative for Concat

 

Shape Difference

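Both torch.cat() and torch.stack() join a sequence of tensors, but they differ in the shape of the result. torch.cat() joins the tensors along an existing dimension, so the number of dimensions stays the same. torch.stack() joins them along a new dimension, so the result has one more dimension than the inputs and all inputs must have the same shape.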

import torch

concat_tensor = torch.cat([torch.rand(2,3) for _ in range(2)], dim=0)
print(f'The shape of concatenated tensor is {concat_tensor.shape}')

stack_tensor = torch.stack([torch.rand(2,3) for _ in range(2)], dim=0)
print(f'The shape of stacked tensor is {stack_tensor.shape}')

Output:

The shape of concatenated tensor is torch.Size([4, 3])
The shape of stacked tensor is torch.Size([2, 2, 3])
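
The relationship between the two calls can be checked directly. The following is a minimal sketch (the variable names are illustrative, not from the original snippet): stacking is equivalent to adding a leading dimension to each tensor and concatenating, and flattening the stacked result back to two dimensions reproduces the plain concatenation.

```python
import torch

tensors = [torch.rand(2, 3) for _ in range(2)]

# stack == cat over tensors that each gained a new leading dimension
stacked = torch.stack(tensors, dim=0)
cat_unsqueezed = torch.cat([t.unsqueeze(0) for t in tensors], dim=0)
assert torch.equal(stacked, cat_unsqueezed)

# merging the first two dimensions of the stacked tensor gives the cat result
flattened = stacked.view(-1, stacked.shape[-1])
assert torch.equal(flattened, torch.cat(tensors, dim=0))

print(stacked.shape, flattened.shape)
# torch.Size([2, 2, 3]) torch.Size([4, 3])
```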

Speed Difference

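Growing a tensor by calling torch.cat() inside a loop is slow, because every call allocates a new tensor and copies all the data accumulated so far. A faster pattern is to append the tensors to a list first and then combine them with a single torch.stack() call, reshaping afterwards if the concatenated layout is needed. In the benchmark below this is about $7.6\times$ faster (0.17045 s versus 0.02236 s); the exact factor depends on the number and size of the tensors and on the hardware.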

import torch, time
# prepare tensors: one list of 10000 tensors of shape (2, 3), reused for 10 timing runs
tensor_lists = [[torch.rand(2, 3) for _ in range(10000)]] * 10
duration_concat = []
duration_stack = []

for t_list in tensor_lists:
    begin = time.time()
    for step, tensor in enumerate(t_list):
        if step == 0:
            new_tensor = tensor
        else:
            new_tensor = torch.cat((new_tensor, tensor), dim=0)
    end = time.time()
    duration_concat.append(end-begin)
print(f'Average time of 10 times concatenation is {sum(duration_concat)/len(duration_concat):.5f}')

for t_list in tensor_lists:
    begin = time.time()
    new_tensor = torch.stack(t_list, dim=0)
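    # view() returns a new tensor, so assign it; the result has shape (20000, 3),
    # the same shape the concatenation loop above produces
    new_tensor = new_tensor.view(-1, new_tensor.shape[-1])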
    end = time.time()
    duration_stack.append(end-begin)
print(f'Average time of 10 times stacking is {sum(duration_stack)/len(duration_stack):.5f}')

Output:

Average time of 10 times concatenation is 0.17045
Average time of 10 times stacking is 0.02236
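
As a sanity check, the sketch below confirms that the loop, a single torch.cat() call on the whole list, and stack-then-view all produce the same tensor, and times each variant. The helper names (cat_in_loop, cat_once, stack_then_view, time_it) are hypothetical and not part of the original benchmark, and the list is shortened to 1000 tensors to keep the run quick. Timings vary by machine, so no expected numbers are shown.

```python
import time
import torch

tensors = [torch.rand(2, 3) for _ in range(1000)]

def cat_in_loop(ts):
    # Re-allocates and copies the growing result on every iteration.
    out = ts[0]
    for t in ts[1:]:
        out = torch.cat((out, t), dim=0)
    return out

def cat_once(ts):
    # One call over the whole list: a single allocation and copy.
    return torch.cat(ts, dim=0)

def stack_then_view(ts):
    # One stack call, then merge the first two dimensions.
    stacked = torch.stack(ts, dim=0)
    return stacked.view(-1, stacked.shape[-1])

def time_it(fn, ts, repeats=5):
    # Average wall-clock time per call over a few repeats.
    start = time.perf_counter()
    for _ in range(repeats):
        fn(ts)
    return (time.perf_counter() - start) / repeats

# All three approaches must agree element for element.
reference = cat_in_loop(tensors)
assert torch.equal(reference, cat_once(tensors))
assert torch.equal(reference, stack_then_view(tensors))

for fn in (cat_in_loop, cat_once, stack_then_view):
    print(f'{fn.__name__}: {time_it(fn, tensors):.5f} s per run')
```

If the single torch.cat() call turns out to be about as fast as the stack version on your machine, that suggests the speedup comes from avoiding the repeated copies in the loop rather than from torch.stack() itself.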