Hi, I have two tensors:
import torch
a = torch.randn(125, 128) # Shape: (125, 128)
b = torch.randn(128, 8, 64) # Shape: (128, 8, 64)
I want the result to have shape (125, 8, 64).
My first observation was that the last dimension of a matches the first dimension of b, so I tried:
result = torch.matmul(a,b)
It gave me the error:
Expected size for first two dimensions of batch2 tensor to be: [128, 128] but got: [128, 8].
How can I do this?
Edit: I also don't want to reshape into 2D and then reshape the result back into 3D.
You can use an einsum:
import torch
a = torch.randn(125, 128) # Shape: (125, 128)
b = torch.randn(128, 8, 64) # Shape: (128, 8, 64)
c = torch.einsum('ij,jkl->ikl', a, b)
print(c.shape)
> torch.Size([125, 8, 64])
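The subscript string 'ij,jkl->ikl' means c[i, k, l] = sum over j of a[i, j] * b[j, k, l]: the shared index j is summed out and the remaining indices i, k, l form the output. A quick sanity check, as a sketch rather than part of the original answer, is to compare against torch.tensordot(a, b, dims=1), which contracts the last dimension of a with the first dimension of b (another way to avoid reshaping manually), and against an explicit sum for one output element. The index values i, k, l below are arbitrary picks for illustration, and the tolerance is an assumption to absorb float32 summation-order differences.

```python
import torch

# Same shapes as in the question
a = torch.randn(125, 128)
b = torch.randn(128, 8, 64)

# einsum: sum over the shared index j, i.e. c[i, k, l] = sum_j a[i, j] * b[j, k, l]
c_einsum = torch.einsum('ij,jkl->ikl', a, b)

# tensordot with dims=1 contracts the last dim of a with the first dim of b
c_tensordot = torch.tensordot(a, b, dims=1)

# Spot-check one output element against the explicit sum over j
i, k, l = 3, 5, 7  # arbitrary indices chosen for illustration
manual = (a[i, :] * b[:, k, l]).sum()

print(c_einsum.shape, c_tensordot.shape)
print(torch.allclose(c_einsum, c_tensordot, atol=1e-4))
print(torch.allclose(c_einsum[i, k, l], manual, atol=1e-4))
```

Both approaches contract over the 128-sized dimension, so they should agree up to floating-point rounding.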