Tags: pytorch, numpy-einsum

How to compute the outer sum (similar to outer product)?


Given tensors x and y, each with shape (num_batches, d), how can I use PyTorch to compute the sum of every pairwise combination of elements of x and y within each batch?

This is similar to an outer product, except that instead of multiplying we want to sum. (This implies I could solve it by exponentiating, taking the outer product, and then taking the log, but of course that has numerical and performance disadvantages.)
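For concreteness, that workaround might look something like the sketch below (x and y here are arbitrary float tensors; torch.einsum('bi,bj->bij', ...) computes the batched outer product):

    import torch

    x = torch.rand(2, 3)
    y = torch.rand(2, 3)

    # Batched outer product of exp(x) and exp(y), then log:
    # log(exp(x[b, i]) * exp(y[b, j])) == x[b, i] + y[b, j]
    osum_via_log = torch.log(torch.einsum('bi,bj->bij', x.exp(), y.exp()))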

It could also be done via a Cartesian product followed by summing each of the combinations.

Essentially, I'd like osum[b, i, j] == x[b, i] + y[b, j]. Can PyTorch do this with tensor operations, without loops?
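For reference, this loop-based sketch pins down exactly what I mean (the name outer_sum_loops is just for illustration; it's far too slow to use in practice):

    import torch

    def outer_sum_loops(x, y):
        # Explicit-loop reference for osum[b, i, j] == x[b, i] + y[b, j].
        num_batches, d = x.shape
        osum = torch.empty(num_batches, d, d, dtype=x.dtype)
        for b in range(num_batches):
            for i in range(d):
                for j in range(d):
                    osum[b, i, j] = x[b, i] + y[b, j]
        return osum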


Solution

  • This can easily be done by introducing singleton dimensions into x and y and broadcasting along them:

    osum = x[..., None] + y[:, None, :]
    

    For example:

    import torch

    x = torch.arange(6).view(2, 3)  # tensor([[0, 1, 2], [3, 4, 5]])
    y = x * 10                      # tensor([[0, 10, 20], [30, 40, 50]])
    osum = x[..., None] + y[:, None, :]
    

    This results in:

    tensor([[[ 0, 10, 20],
             [ 1, 11, 21],
             [ 2, 12, 22]],
    
            [[33, 43, 53],
             [34, 44, 54],
             [35, 45, 55]]])
    
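    An equivalent spelling uses unsqueeze to insert the singleton dimensions explicitly; this is just another way to write the same broadcast (a small sketch, reusing the example tensors above):

    import torch

    x = torch.arange(6).view(2, 3)
    y = x * 10

    # x.unsqueeze(2) has shape (2, 3, 1); y.unsqueeze(1) has shape (2, 1, 3).
    # Broadcasting the sum gives shape (2, 3, 3), same as the indexing form.
    assert torch.equal(x[..., None] + y[:, None, :],
                       x.unsqueeze(2) + y.unsqueeze(1))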

    Update (July 14th): How does it work?

    You have two tensors, x and y, each of shape (b, n), and you want to compute:

    osum[b,i,j] = x[b, i] + y[b, j]
    

    We can, conceptually, create new variables xx and yy by repeating each element of x and y along a third dimension, such that:

    xx[b, i, j] == x[b, i]  # for all j
    yy[b, i, j] == y[b, j]  # for all i
    

    With these new variables, it is easy to see that:

    osum = xx + yy
    

    since, by definition,

    osum[b, i, j] == xx[b, i, j] + yy[b, i, j] == x[b, i] + y[b, j]
    

    Now, you could use methods such as Tensor.expand or Tensor.repeat to explicitly create xx and yy, but why bother? Since their elements are just trivial repetitions of the elements of x and y along specific dimensions, broadcasting does this implicitly for you, as the sketch below shows.
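
    For illustration, here is what that explicit construction might look like with expand (a sketch; expand creates broadcast views without copying memory, whereas repeat would make actual copies):

    import torch

    x = torch.arange(6).view(2, 3)
    y = x * 10
    b, n = x.shape

    # Explicitly build xx[b, i, j] == x[b, i] and yy[b, i, j] == y[b, j].
    xx = x.unsqueeze(-1).expand(b, n, n)
    yy = y.unsqueeze(1).expand(b, n, n)

    # Same result as the implicit broadcasting one-liner.
    assert torch.equal(xx + yy, x[..., None] + y[:, None, :])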