pytorch matmul

torch matmul between 2D and 3D tensor


Hi, I have two tensors:

a = torch.randn(125, 128)    # Shape: (125, 128)
b = torch.randn(128, 8, 64)  # Shape: (128, 8, 64)

I want the result to have a shape of (125, 8, 64).

My first observation is that the last dimension of a matches the first dimension of b, so I tried:

result = torch.matmul(a,b)

It gave me the error:

Expected size for first two dimensions of batch2 tensor to be: [128, 128] but got: [128, 8].

How can I do this?

Edit: I also don't want to reshape b into 2D and then reshape the result back into 3D.


Solution

  • You can use torch.einsum, which contracts the shared dimension of size 128 directly, with no reshaping:

    import torch

    a = torch.randn(125, 128)    # Shape: (125, 128)
    b = torch.randn(128, 8, 64)  # Shape: (128, 8, 64)
    # Sum over the shared index j (size 128); keep i from a and k, l from b
    c = torch.einsum('ij,jkl->ikl', a, b)
    print(c.shape)  # torch.Size([125, 8, 64])
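
  • As a rough sketch of two alternatives, assuming the goal is simply to contract the last dimension of a with the first dimension of b: torch.tensordot with dims=1 should do the same contraction, and torch.matmul can also be made to work if b is permuted so that the size-128 dimension sits second to last. The variable names c_einsum, c_tensordot and c_matmul are just illustrative.

```python
import torch

a = torch.randn(125, 128)    # Shape: (125, 128)
b = torch.randn(128, 8, 64)  # Shape: (128, 8, 64)

# Reference result from the einsum approach
c_einsum = torch.einsum('ij,jkl->ikl', a, b)

# tensordot with dims=1 contracts the last dim of a with the first dim of b
c_tensordot = torch.tensordot(a, b, dims=1)

# matmul treats b as a batch of matrices over its last two dims, so move
# the size-128 dim to second-to-last: (8, 128, 64). a is broadcast over
# the batch, giving (8, 125, 64), which is permuted back to (125, 8, 64).
c_matmul = torch.matmul(a, b.permute(1, 0, 2)).permute(1, 0, 2)

print(c_tensordot.shape)  # torch.Size([125, 8, 64])
print(c_matmul.shape)     # torch.Size([125, 8, 64])
```

    This also hints at why the original call fails: matmul reads b as 128 matrices of shape (8, 64), and the inner dimension 8 does not match the 128 columns of a. The permute calls return views, so they only reorder dimensions and are not the 2D reshape the question wanted to avoid.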