With numpy, I can do a simple matrix multiplication like this:
import numpy
a = numpy.ones((3, 2))
b = numpy.ones((2, 1))
result = a.dot(b)
However, this does not work with PyTorch:
import torch
a = torch.ones((3, 2))
b = torch.ones((2, 1))
result = torch.dot(a, b)
This code throws the following error:
RuntimeError: 1D tensors expected, but got 2D and 2D tensors
How do I perform matrix multiplication in PyTorch?
Use torch.mm:
torch.mm(a, b)
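For example, with the tensors from the question, torch.mm returns a 3 x 1 tensor in which every entry is 2 (each entry is the sum of two products of ones). The comparison against NumPy's a.dot(b) at the end is only a sanity check added for illustration:

```python
import numpy
import torch

# Same shapes as in the question: (3, 2) x (2, 1) -> (3, 1)
a = torch.ones((3, 2))
b = torch.ones((2, 1))
result = torch.mm(a, b)

print(result.shape)  # torch.Size([3, 1])
print(result)        # every entry is 2.0

# Sanity check against NumPy's a.dot(b) on the same data
expected = numpy.ones((3, 2)).dot(numpy.ones((2, 1)))
assert numpy.allclose(result.numpy(), expected)
```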
torch.dot() behaves differently from np.dot(). There has been some discussion about what the desirable behaviour would be here. In current PyTorch versions, torch.dot() only computes the inner product of two 1D tensors with the same number of elements; it does not fall back to matrix multiplication for 2D inputs, which is why you get the error above. (Older versions instead treated both a and b as 1D vectors irrespective of their original shape, which made your a a vector of length 6 and your b a vector of length 2, so their inner product couldn't be computed either.) For matrix multiplication in PyTorch, use torch.mm(). NumPy's np.dot(), in contrast, is more flexible: it computes the inner product for 1D arrays and performs matrix multiplication for 2D arrays.
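A short sketch contrasting the two, assuming a recent PyTorch release in which torch.dot rejects non-1D input (as in the question's error message). The names u, v, inner, raised, np_inner and np_matrix are illustrative only:

```python
import numpy
import torch

# torch.dot: inner product of two 1D tensors with the same number of elements
u = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([4.0, 5.0, 6.0])
inner = torch.dot(u, v)  # 1*4 + 2*5 + 3*6 = 32.0

# With 2D inputs, torch.dot raises a RuntimeError instead of multiplying matrices
try:
    torch.dot(torch.ones((3, 2)), torch.ones((2, 1)))
    raised = False
except RuntimeError:
    raised = True

# np.dot switches behaviour based on the dimensions of its inputs
np_inner = numpy.dot(numpy.array([1.0, 2.0, 3.0]), numpy.array([4.0, 5.0, 6.0]))
np_matrix = numpy.dot(numpy.ones((3, 2)), numpy.ones((2, 1)))

print(inner.item(), raised, np_inner, np_matrix.shape)  # 32.0 True 32.0 (3, 1)
```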
torch.matmul performs matrix multiplication if both arguments are 2D and computes their dot product if both arguments are 1D. For inputs of these dimensions, its behaviour is the same as np.dot. It also supports broadcasting, and lets you do matrix x matrix, matrix x vector and vector x vector operations in batches.
import torch

n, m, k, j = 4, 3, 2, 5  # example sizes

# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b)  # result shape: torch.Size([])

# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b)  # result shape: torch.Size([m, j])
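The batching and broadcasting mentioned above can be illustrated with a few more shape checks. This is a minimal sketch with arbitrary sizes; the variable names (mv, bmm, bc, vv, batch) are illustrative only. Note that two 2D inputs are always treated as matrices, so the batched vector x vector case adds singleton dimensions with unsqueeze:

```python
import torch

m, k, j = 3, 4, 5  # example sizes, chosen arbitrarily
batch = 10

# matrix x vector: (m, k) @ (k,) -> (m,)
mv = torch.matmul(torch.rand(m, k), torch.rand(k))
print(mv.shape)  # torch.Size([3])

# batched matrix x matrix: (batch, m, k) @ (batch, k, j) -> (batch, m, j)
bmm = torch.matmul(torch.rand(batch, m, k), torch.rand(batch, k, j))
print(bmm.shape)  # torch.Size([10, 3, 5])

# broadcasting: one (k, j) matrix is applied to every matrix in the batch
bc = torch.matmul(torch.rand(batch, m, k), torch.rand(k, j))
print(bc.shape)  # torch.Size([10, 3, 5])

# batched vector x vector: (batch, 1, k) @ (batch, k, 1) -> (batch, 1, 1)
x = torch.rand(batch, k)
y = torch.rand(batch, k)
vv = torch.matmul(x.unsqueeze(1), y.unsqueeze(2))[:, 0, 0]  # shape (batch,)

# Each entry equals the dot product of the corresponding rows of x and y
assert torch.allclose(vv, (x * y).sum(dim=1))
```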