How to do product of matrices in PyTorch

Python, Matrix, PyTorch

Python Problem Overview


In numpy I can do a simple matrix multiplication like this:

import numpy

a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))

However, when I am trying this with PyTorch Tensors, this does not work:

import torch

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)
print(a)
print(a.size())

print(b)
print(b.size())

print(torch.dot(a, b))

This code throws the following error:

> RuntimeError: inconsistent tensor size at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503

Any ideas how matrix multiplication can be conducted in PyTorch?

Python Solutions


Solution 1 - Python

You're looking for

torch.mm(a,b)

Note that torch.dot() behaves differently from np.dot(). There's been some discussion about what would be desirable here. Specifically, torch.dot() treats both a and b as 1D vectors (irrespective of their original shape) and computes their inner product. The error is thrown because this behaviour makes your a a vector of length 6 and your b a vector of length 2; hence their inner product can't be computed. (More recent PyTorch releases reject non-1D arguments to torch.dot() outright, so the call still fails, with a different message.)

For matrix multiplication in PyTorch, use torch.mm(). NumPy's np.dot(), in contrast, is more flexible: it computes the inner product for 1D arrays and performs matrix multiplication for 2D arrays.
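As a minimal sketch of the fix, here are the tensors from the question multiplied with torch.mm(), followed by torch.dot() applied to what it is meant for, two 1D tensors of equal length. The names u and v are made up for this illustration, and torch.tensor with float literals is used in place of torch.Tensor.

```python
import torch

# The tensors from the question: a ends up 3x2, b ends up 2x1.
a = torch.tensor([[1., 2., 3.], [1., 2., 3.]]).view(-1, 2)
b = torch.tensor([[2., 1.]]).view(2, -1)

# Matrix product: (3, 2) x (2, 1) -> (3, 1)
product = torch.mm(a, b)
print(product.shape)
print(product)

# torch.dot is for two 1D tensors of the same length (u, v are illustrative)
u = torch.tensor([1., 2., 3.])
v = torch.tensor([4., 5., 6.])
inner = torch.dot(u, v)
print(inner)
```

The result of torch.mm() has one row per row of a and one column per column of b, which matches what a.dot(b) returns for 2D arrays in NumPy.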

By popular demand, the function torch.matmul performs matrix multiplication if both arguments are 2D and computes their dot product if both arguments are 1D. For inputs of such dimensions, its behaviour is the same as np.dot. It also supports broadcasting, so matrix x matrix, matrix x vector and vector x vector operations can be done in batches. For more info, see its docs.

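import torch

n, m, k, j = 5, 2, 3, 4  # example sizes, any positive integers work

# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b) # result has shape torch.Size([])

# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b) # result has shape torch.Size([m, j])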
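The batched and broadcasting cases mentioned above could look roughly like this. The shapes (10, 3, 4 and 5) and the variable names are arbitrary choices for the sketch, not anything prescribed by the answer.

```python
import torch

batch = torch.rand(10, 3, 4)   # a batch of ten 3x4 matrices
mat = torch.rand(4, 5)         # a single 4x5 matrix
vec = torch.rand(4)            # a single vector of length 4
other = torch.rand(10, 4, 5)   # a second batch of ten 4x5 matrices

# batch x matrix: the single matrix is broadcast across the batch
batch_times_mat = torch.matmul(batch, mat)
print(batch_times_mat.shape)   # torch.Size([10, 3, 5])

# batch x vector: every matrix in the batch multiplies the same vector
batch_times_vec = torch.matmul(batch, vec)
print(batch_times_vec.shape)   # torch.Size([10, 3])

# batch x batch: matrices are multiplied pairwise along the batch dimension
batch_times_batch = torch.matmul(batch, other)
print(batch_times_batch.shape) # torch.Size([10, 3, 5])
```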

Solution 2 - Python

If you want to do a matrix (rank 2 tensor) multiplication you can do it in four equivalent ways:

AB = A.mm(B) # computes A.B (matrix multiplication)
# or
AB = torch.mm(A, B)
# or
AB = torch.matmul(A, B)
# or, even simpler
AB = A @ B # Python 3.5+

There are a few subtleties. From the PyTorch documentation:

> torch.mm does not broadcast. For broadcasting matrix products, see torch.matmul().

For instance, you cannot multiply two 1-dimensional vectors with torch.mm, nor multiply batched matrices (rank 3). To this end, you should use the more versatile torch.matmul. For an extensive list of the broadcasting behaviours of torch.matmul, see the documentation.
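To see the difference in practice, the sketch below checks which inputs torch.mm accepts. The helper mm_accepts is hypothetical, written only for this illustration, and it assumes that torch.mm signals bad input shapes with a RuntimeError.

```python
import torch

def mm_accepts(x, y):
    """Hypothetical helper: True if torch.mm(x, y) runs, False if it raises."""
    try:
        torch.mm(x, y)
        return True
    except RuntimeError:
        return False

u, v = torch.rand(4), torch.rand(4)                        # 1D vectors
batch_a, batch_b = torch.rand(10, 3, 4), torch.rand(10, 4, 5)  # rank-3 batches

print(mm_accepts(torch.rand(3, 4), torch.rand(4, 5)))  # plain 2D matrices
print(mm_accepts(u, v))                                # 1D inputs
print(mm_accepts(batch_a, batch_b))                    # batched inputs

# torch.matmul handles the cases that torch.mm rejects
inner = torch.matmul(u, v)               # 0-dimensional result
batched = torch.matmul(batch_a, batch_b) # one product per batch entry
print(inner.shape, batched.shape)
```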

For element-wise multiplication, you can simply do (if A and B have the same shape)

A * B # element-wise matrix multiplication (Hadamard product)
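Since the two operations are easy to mix up, here is a small side-by-side sketch with made-up 2x2 values. torch.mul is the function form of the * operator.

```python
import torch

A = torch.tensor([[1., 2.],
                  [3., 4.]])
B = torch.tensor([[0., 1.],
                  [1., 0.]])

hadamard = A * B            # multiplies matching entries
hadamard_fn = torch.mul(A, B)  # same operation, function form
matrix_product = A @ B      # rows of A times columns of B

print(hadamard)
print(matrix_product)
```

With these values, A * B keeps only the off-diagonal entries of A, while A @ B swaps the columns of A, so the two results differ.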

Solution 3 - Python

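Use torch.mm(a, b) or torch.matmul(a, b).
For 2D tensors both give the same result; torch.matmul additionally handles 1D and batched inputs. Note that the address printed below belongs to the type object that holds both functions, so it does not by itself show that they are the same function.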

>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>

There's one more option that may be good to know: the @ operator (credit: @Simon H.).

>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])    

The three give the same results.
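Rather than comparing printed values by eye, the agreement can be checked programmatically. This is a sketch using torch.allclose; the seed and the variable names are arbitrary.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only to make the run repeatable
a = torch.randn(2, 3)
b = torch.randn(3, 4)

via_operator = a @ b
via_mm = a.mm(b)
via_matmul = a.matmul(b)

# Compare up to floating-point tolerance
print(torch.allclose(via_operator, via_mm))
print(torch.allclose(via_mm, via_matmul))
```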

Related links:
Matrix multiplication operator
PEP 465 -- A dedicated infix operator for matrix multiplication

Solution 4 - Python

You can use "@" to compute the matrix product of two tensors in PyTorch (for two 1D tensors it gives their dot product).

import torch

a = torch.tensor([[1,2],
                  [3,4]])
b = torch.tensor([[5,6],
                  [7,8]])
c = a@b  # matrix product
print(c)

d = a*b  # element-wise multiplication
print(d)

Attributions

All content for this solution is sourced from the original question on Stack Overflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stack Overflow
Question | blckbird | View Question on Stack Overflow
Solution 1 - Python | mbpaulus | View Answer on Stack Overflow
Solution 2 - Python | BiBi | View Answer on Stack Overflow
Solution 3 - Python | David Jung | View Answer on Stack Overflow
Solution 4 - Python | Nivesh Gadipudi | View Answer on Stack Overflow