What does "unsqueeze" do in PyTorch?


Python Problem Overview


The PyTorch documentation says:

> Returns a new tensor with a dimension of size one inserted at the specified position. [...]
>
>     >>> x = torch.tensor([1, 2, 3, 4])
>     >>> torch.unsqueeze(x, 0)
>     tensor([[ 1,  2,  3,  4]])
>     >>> torch.unsqueeze(x, 1)
>     tensor([[ 1],
>             [ 2],
>             [ 3],
>             [ 4]])

Python Solutions


Solution 1 - Python

If you look at the shape of the tensor before and after, you see that it was (4,) before and is (1, 4) afterwards when the second parameter is 0, or (4, 1) when the second parameter is 1. So a 1 was inserted into the shape of the tensor at axis 0 or 1, depending on the value of the second parameter.

This is the opposite of np.squeeze() (nomenclature borrowed from MATLAB), which removes axes of size 1 (singletons).
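The shape change described above can be checked directly. This is a minimal sketch that reuses the tensor from the documentation example; the variable names `row`, `col` and `restored` are only illustrative.

```python
import torch

x = torch.tensor([1, 2, 3, 4])

row = torch.unsqueeze(x, 0)  # insert a size-1 axis at position 0
col = torch.unsqueeze(x, 1)  # insert a size-1 axis at position 1

print(tuple(x.shape))    # (4,)
print(tuple(row.shape))  # (1, 4)
print(tuple(col.shape))  # (4, 1)

# squeeze undoes the operation by removing the size-1 axis again
restored = torch.squeeze(row, 0)
print(torch.equal(restored, x))  # True
```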

Solution 2 - Python

unsqueeze turns an n-dimensional tensor into an (n+1)-dimensional one by adding an extra dimension of size 1. However, since it is ambiguous which axis the new dimension should lie across (i.e. in which direction it should be "unsqueezed"), this needs to be specified by the dim argument.

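For example, unsqueeze can be applied to a 2D tensor in three different ways, one for each position at which the new axis can be inserted (dim=0, dim=1, or dim=2).

[Image from the original answer, not reproduced here]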

The resulting unsqueezed tensors have the same information, but the indices used to access them are different.
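To make the point about indices concrete, here is a small sketch with an arbitrary 2×3 tensor (the values and the names `a`, `b`, `c` are chosen only for illustration). The same element is reachable in every variant, but the extra index 0 sits at a different position each time.

```python
import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])  # shape (2, 3)

a = t.unsqueeze(0)  # shape (1, 2, 3)
b = t.unsqueeze(1)  # shape (2, 1, 3)
c = t.unsqueeze(2)  # shape (2, 3, 1)

# The element t[1, 2] in each unsqueezed tensor
print(t[1, 2].item())     # 6
print(a[0, 1, 2].item())  # 6
print(b[1, 0, 2].item())  # 6
print(c[1, 2, 0].item())  # 6
```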

Solution 3 - Python

The dim argument indicates the position at which to add the dimension. torch.unsqueeze adds an additional dimension of size 1 to the tensor.

So let's say you have a 1D tensor of shape (3,):

  • If you add a dimension at the 0 position, it will be of shape (1, 3), which means 1 row and 3 columns.
  • If you add at the 1 position, it will be (3, 1), which means 3 rows and 1 column.

If you have a 2D tensor of shape (2, 2):

  • If you add an extra dimension at the 0 position, the tensor will have a shape of (1, 2, 2), which means 1 channel, 2 rows and 2 columns.
  • If you add at the 1 position, it will be of shape (2, 1, 2), so it will have 2 channels, 1 row and 2 columns.
  • If you add it at the 2 position, the tensor will be of shape (2, 2, 1), which means 2 channels, 2 rows and 1 column.
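The shapes listed above can be enumerated in a couple of lines. This sketch assumes zero-filled tensors of the shapes mentioned in the answer; `v`, `m`, `v_shapes` and `m_shapes` are illustrative names.

```python
import torch

v = torch.zeros(3)     # 1D tensor, shape (3,)
m = torch.zeros(2, 2)  # 2D tensor, shape (2, 2)

# Try every non-negative insertion position: 0 .. tensor.dim()
v_shapes = [tuple(v.unsqueeze(d).shape) for d in range(v.dim() + 1)]
m_shapes = [tuple(m.unsqueeze(d).shape) for d in range(m.dim() + 1)]

print(v_shapes)  # [(1, 3), (3, 1)]
print(m_shapes)  # [(1, 2, 2), (2, 1, 2), (2, 2, 1)]
```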

Solution 4 - Python

Here are the descriptions from the PyTorch docs:

> ### torch.squeeze(input, dim=None, *, out=None) → Tensor
>
> Returns a tensor with all the dimensions of input of size 1 removed.
>
> For example, if input is of shape: (A×1×B×C×1×D) then the out tensor will be of shape: (A×B×C×D).
>
> When dim is given, a squeeze operation is done only in the given dimension. If input is of shape: (A×1×B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (A×B).

> ### torch.unsqueeze(input, dim) → Tensor
>
> Returns a new tensor with a dimension of size one inserted at the specified position.
>
> The returned tensor shares the same underlying data with this tensor.
>
> A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
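Two details from the quoted docs are easy to miss: how negative dim values are mapped, and the fact that the result shares data with the input. The sketch below illustrates both on small example tensors (the names `x`, `v`, `u` and `s` are illustrative only).

```python
import torch

x = torch.zeros(3, 4)  # x.dim() == 2, so valid dims are -3 .. 2

# A negative dim is mapped to dim + x.dim() + 1
print(tuple(x.unsqueeze(-1).shape))  # (3, 4, 1), same as dim=2
print(tuple(x.unsqueeze(-3).shape))  # (1, 3, 4), same as dim=0

# The result shares storage with the input: writing through it changes the input
v = torch.tensor([1, 2, 3])
u = v.unsqueeze(0)  # shape (1, 3)
u[0, 0] = 100
print(v.tolist())   # [100, 2, 3]

# squeeze with and without dim, following the (A×1×B) example in the docs
s = torch.zeros(2, 1, 3)
print(tuple(s.squeeze(0).shape))  # (2, 1, 3), unchanged because axis 0 has size 2
print(tuple(s.squeeze(1).shape))  # (2, 3)
```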

Solution 5 - Python

unsqueeze is a method that changes a tensor's dimensions so that operations such as element-wise tensor multiplication become possible. It leaves the data untouched and returns a tensor with one extra dimension of size 1.

For example: if you want to multiply a tensor of size (4) with a tensor of size (4, N, N), you'll get an error, because broadcasting aligns shapes from the last dimension and 4 does not match N (if N happens to be 4, there is no error, but the values are applied along the last axis rather than the first). Using the unsqueeze method, you can convert the tensor to size (4, 1, 1). Since the dimensions of size 1 can be broadcast against N, you'll now be able to multiply the two tensors.
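A minimal sketch of this broadcasting case, assuming N = 3 and illustrative names (`w`, `batch`, `scaled`); the weights are meant to scale each of the 4 matrices along the first axis.

```python
import torch

N = 3
w = torch.tensor([1.0, 2.0, 3.0, 4.0])  # shape (4,)
batch = torch.ones(4, N, N)             # shape (4, N, N)

# Direct multiplication fails: trailing dimensions 4 and N=3 do not match
try:
    w * batch
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True

# Add two trailing size-1 axes so w broadcasts over the N x N matrices
w3 = w.unsqueeze(1).unsqueeze(2)  # shape (4, 1, 1)
scaled = w3 * batch               # shape (4, N, N)

print(tuple(scaled.shape))       # (4, 3, 3)
print(scaled[:, 0, 0].tolist())  # [1.0, 2.0, 3.0, 4.0]
```

Indexing with `w[:, None, None]` is an equivalent way to get the (4, 1, 1) shape.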

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stackoverflow
Question | StarckOverflar | View Question on Stackoverflow
Solution 1 - Python | norok2 | View Answer on Stackoverflow
Solution 2 - Python | iacob | View Answer on Stackoverflow
Solution 3 - Python | Voontent | View Answer on Stackoverflow
Solution 4 - Python | prosti | View Answer on Stackoverflow
Solution 5 - Python | R.Sankar Mahesh | View Answer on Stackoverflow