Suppose I have two tensors with the following sizes:
A = [N x L x T]
B = [N x T x K]
I would like to multiply matching slices of the two tensors, like this:
matmul_slice = A[0,:,:] @ B[0,:,:] = [L x T] @ [T x K] = [L x K]
I would like to repeat this N times along dimension 0, so that I end up with a final tensor of size [N x L x K].
I do not want to loop over N, since that slows down the computation. I have been experimenting with torch.matmul and einsum, but I cannot get the correct answer.
How can I achieve this compactly?
torch.bmm (batched matrix multiplication) is what you need, although torch.matmul should be equivalent in your case. I think you should recheck your computation.
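As a minimal sketch (the sizes and random inputs below are arbitrary, chosen only for illustration), this compares torch.bmm, torch.matmul, and an equivalent torch.einsum call against the explicit loop over N. All three should produce the same [N x L x K] tensor as stacking A[n] @ B[n] for every n.

```python
import torch

# Arbitrary example sizes, for illustration only
N, L, T, K = 4, 3, 5, 2
torch.manual_seed(0)
# float64 keeps rounding differences between methods negligible
A = torch.randn(N, L, T, dtype=torch.float64)
B = torch.randn(N, T, K, dtype=torch.float64)

# Batched matrix multiply: out[n] = A[n] @ B[n] for every n
out_bmm = torch.bmm(A, B)  # shape [N, L, K]

# For two 3-D inputs, torch.matmul treats dim 0 as a batch dimension
out_matmul = torch.matmul(A, B)  # same as A @ B

# einsum form: keep n as a batch index, sum over the shared index t
out_einsum = torch.einsum('nlt,ntk->nlk', A, B)

# Reference result built with the explicit loop we want to avoid
out_loop = torch.stack([A[n] @ B[n] for n in range(N)])

print(out_bmm.shape)  # torch.Size([4, 3, 2])
print(torch.allclose(out_bmm, out_loop),
      torch.allclose(out_matmul, out_loop),
      torch.allclose(out_einsum, out_loop))
```

One difference worth knowing: torch.bmm requires both inputs to be exactly 3-D with the same batch size and does not broadcast, whereas torch.matmul broadcasts batch dimensions, so it also accepts, for example, a single [T x K] matrix for B.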