What's the difference between torch.stack() and torch.cat()?

Python, PyTorch

Python Problem Overview


What's the difference between torch.cat and torch.stack?


OpenAI's REINFORCE and actor-critic examples for reinforcement learning have the following:

# REINFORCE:
policy_loss = torch.cat(policy_loss).sum()

# actor-critic:
loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

Python Solutions


Solution 1 - Python

stack

> Concatenates sequence of tensors along a new dimension.

cat

> Concatenates the given sequence of seq tensors in the given dimension.

So if A and B are of shape (3, 4):

  • torch.cat([A, B], dim=0) will be of shape (6, 4)
  • torch.stack([A, B], dim=0) will be of shape (2, 3, 4)
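
A quick, runnable check of those shapes:

import torch

A = torch.randn(3, 4)
B = torch.randn(3, 4)

print(torch.cat([A, B], dim=0).shape)    # torch.Size([6, 4])
print(torch.stack([A, B], dim=0).shape)  # torch.Size([2, 3, 4])

In fact, torch.stack([A, B], dim=0) produces the same result as torch.cat([A.unsqueeze(0), B.unsqueeze(0)], dim=0): stack inserts the new dimension for you, while cat requires it to already exist.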

Solution 2 - Python

t1 = torch.tensor([[1, 2],
                   [3, 4]])

t2 = torch.tensor([[5, 6],
                   [7, 8]])
torch.stack: 'stacks' a sequence of tensors along a new dimension.

torch.cat: 'concatenates' a sequence of tensors along an existing dimension.
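
In code, using t1 and t2 from above:

torch.stack((t1, t2))  # shape (2, 2, 2): a new leading dimension indexes t1 and t2
# tensor([[[1, 2],
#          [3, 4]],
#
#         [[5, 6],
#          [7, 8]]])

torch.cat((t1, t2))    # shape (4, 2): the rows of t2 follow the rows of t1
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])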

These functions are analogous to numpy.stack and numpy.concatenate.
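
For example:

import numpy as np

a = np.ones((3, 4))
b = np.ones((3, 4))

print(np.concatenate([a, b], axis=0).shape)  # (6, 4)
print(np.stack([a, b], axis=0).shape)        # (2, 3, 4)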

Solution 3 - Python

The earlier answers lack a self-contained example, so here is one:

import torch

# stack vs cat

# cat "extends" a list in the given dimension e.g. adds more rows or columns

x = torch.randn(2, 3)
print(f'{x.size()}')

# add more rows (the first dimension grows from 2 to 6)
xnew_from_cat = torch.cat((x, x, x), 0)
print(f'{xnew_from_cat.size()}')

# add more columns (the second dimension grows from 3 to 9)
xnew_from_cat = torch.cat((x, x, x), 1)
print(f'{xnew_from_cat.size()}')

print()

# stack plays the same role as append does for lists: it doesn't change the
# original tensors, but adds a new dimension to the result, so you can recover
# each original tensor by indexing into that new dimension
xnew_from_stack = torch.stack((x, x, x, x), 0)
print(f'{xnew_from_stack.size()}')

xnew_from_stack = torch.stack((x, x, x, x), 1)
print(f'{xnew_from_stack.size()}')

xnew_from_stack = torch.stack((x, x, x, x), 2)
print(f'{xnew_from_stack.size()}')

# with no dim argument, stack uses dim=0, i.e. the new dimension goes at the front
xnew_from_stack = torch.stack((x, x, x, x))
print(f'{xnew_from_stack.size()}')

print('I like to think of xnew_from_stack as a "tensor list" that you can pop from the front')

output:

torch.Size([2, 3])
torch.Size([6, 3])
torch.Size([2, 9])
torch.Size([4, 2, 3])
torch.Size([2, 4, 3])
torch.Size([2, 3, 4])
torch.Size([4, 2, 3])
I like to think of xnew_from_stack as a "tensor list" that you can pop from the front

For reference, here are the definitions:

> cat: Concatenates the given sequence of seq tensors in the given dimension. The consequence is that a specific dimension changes size, e.g. with dim=0 you are adding rows, so the size of dimension 0 grows.

> stack: Concatenates a sequence of tensors along a new dimension. I like to think of this as the torch "append" operation, since you can recover your original tensor by "popping" it from the front via indexing. With no dim argument, it stacks along a new leading dimension (dim=0).
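
To make the "append/pop" analogy concrete, indexing the new dimension recovers the original tensor:

import torch

x = torch.randn(2, 3)
stacked = torch.stack((x, x, x))   # shape (3, 2, 3)
assert torch.equal(stacked[0], x)  # "pop" the original tensor back off the front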




Update: with a nested list of tensors of the same size

def tensorify(lst):
    """
    Turn a nested list of tensors into a single tensor.

    The list must be a nested list of tensors with no varying lengths within
    a dimension. A nested list with lengths [D1, D2, ..., DN] becomes a tensor
    whose leading dimensions are [D1, D2, ..., DN].

    :return: a single stacked tensor
    """
    # base case: the current list is not nested any more, so turn it into a tensor
    if not isinstance(lst[0], list):
        if isinstance(lst, torch.Tensor):
            return lst
        elif isinstance(lst[0], torch.Tensor):
            return torch.stack(lst, dim=0)
        else:  # the elements of lst are floats, ints, or similar
            return torch.tensor(lst)
    current_dimension_i = len(lst)
    for d_i in range(current_dimension_i):
        tensor = tensorify(lst[d_i])
        lst[d_i] = tensor
    # every lst[d_i] is now a tensor; stack them along a new leading dimension
    tensor_lst = torch.stack(lst, dim=0)
    return tensor_lst

Here are a few unit tests (I didn't write more, but it worked with my real code so I trust it's fine; feel free to add more tests):


def test_tensorify():
    t = [1, 2, 3]
    print(tensorify(t).size())
    tt = [t, t, t]
    print(tensorify(tt))
    ttt = [tt, tt, tt]
    print(tensorify(ttt))

if __name__ == '__main__':
    test_tensorify()
    print('Done\a')

Solution 4 - Python

If someone is looking into the performance aspects of this, I've done a small experiment. In my case, I needed to convert a list of scalar tensors into a single tensor.

import torch
torch.__version__ # 1.10.2
x = [torch.randn(1) for _ in range(10000)]
torch.cat(x).shape, torch.stack(x).shape # torch.Size([10000]), torch.Size([10000, 1])

%timeit torch.cat(x) # 1.5 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit torch.cat(x).reshape(-1,1) # 1.95 ms ± 534 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit torch.stack(x) # 5.36 ms ± 643 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

My conclusion is that even if you want the additional dimension that torch.stack adds, it is faster to use torch.cat and then reshape.
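
Continuing with the x list from above, a quick sanity check that the two routes agree:

assert torch.equal(torch.cat(x).reshape(-1, 1), torch.stack(x))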

Note: this post is taken from the PyTorch forum (I am the author of the original post)

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

  • Question: Gulzar (View Question on Stackoverflow)
  • Solution 1 - Python: Jatentaki (View Answer on Stackoverflow)
  • Solution 2 - Python: iacob (View Answer on Stackoverflow)
  • Solution 3 - Python: Charlie Parker (View Answer on Stackoverflow)
  • Solution 4 - Python: casualcausality (View Answer on Stackoverflow)