pytorch

How to do batched dot product in PyTorch?


I have an input tensor of size [B, N, 3] and a test tensor of size [N, 3]. I want to take the dot product of each 3-element row of the input with the corresponding row of the test tensor, so that the result has size [B, N]. Is this possible?
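Written element-wise, the requested operation is a dot product over the last (size-3) axis. Here the input tensor is called a, the test tensor b, and the [B, N] result c; these names are chosen for illustration:

```latex
c_{ij} = \sum_{k=1}^{3} a_{ijk}\, b_{jk}, \qquad i = 1,\dots,B,\quad j = 1,\dots,N
```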


Solution

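  • Yes, it's possible with torch.einsum, which multiplies matching entries and sums over the index that is missing from the output spec (here k, the last dimension):

    import torch

    a = torch.randn(5, 4, 3)  # input tensor, [B, N, 3]
    b = torch.randn(4, 3)     # test tensor, [N, 3]

    # 'ijk,jk->ij': k does not appear after '->', so it is summed over
    c = torch.einsum('ijk,jk->ij', a, b)  # torch.Size([5, 4])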
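The same result can also be obtained without einsum. The sketch below is not part of the original answer. It compares the einsum call with two common alternatives: broadcasting followed by a sum over the last dimension, and a batched torch.matmul after reshaping each row into a [1, 3] @ [3, 1] product. It also checks all three against a plain Python loop. The sizes (B=5, N=4) and the tolerance passed to torch.allclose are illustrative assumptions.

```python
import torch

torch.manual_seed(0)  # fixed seed so the comparison is reproducible

B, N = 5, 4
a = torch.randn(B, N, 3)  # input tensor, [B, N, 3]
b = torch.randn(N, 3)     # test tensor, [N, 3]

# 1) einsum, as in the answer: sum over k, keep i (batch) and j (row)
c_einsum = torch.einsum('ijk,jk->ij', a, b)

# 2) Broadcasting: b [N, 3] broadcasts against a [B, N, 3]; then sum out the last dim
c_bcast = (a * b).sum(dim=-1)

# 3) Batched matmul: a -> [B, N, 1, 3], b -> [N, 3, 1];
#    batch dims [B, N] and [N] broadcast, giving [B, N, 1, 1], then drop the two 1s
c_matmul = torch.matmul(a.unsqueeze(-2), b.unsqueeze(-1)).squeeze(-1).squeeze(-1)

# Reference: explicit loops computing sum_k a[i, j, k] * b[j, k]
c_loop = torch.empty(B, N)
for i in range(B):
    for j in range(N):
        c_loop[i, j] = torch.dot(a[i, j], b[j])

print(c_einsum.shape)  # torch.Size([5, 4])
# Small absolute tolerance because float32 sums may be accumulated in different orders
print(torch.allclose(c_einsum, c_bcast, atol=1e-6))
print(torch.allclose(c_einsum, c_matmul, atol=1e-6))
print(torch.allclose(c_einsum, c_loop, atol=1e-6))
```

The broadcasting version is often the easiest to read. It does materialize the full [B, N, 3] product before summing, which is negligible for a last dimension of size 3.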