python, pytorch, cosine-similarity, array-broadcasting, numpy-einsum

In PyTorch, how can I avoid an expensive broadcast when adding two tensors then immediately collapsing?


I have two 2-d tensors which align via broadcasting, so if I add/subtract them I incur a huge 3-d intermediate tensor. I don't actually need that intermediate, since I immediately take a mean over one of its dimensions. In this demo I unsqueeze the tensors to show how they align, but they are 2-d otherwise.

x = torch.tensor(...)              # (batch, 1, B)
y = torch.tensor(...)              # (1, A, B)
out = torch.cos(x - y).mean(dim=2) # (batch, A)
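
For a sense of scale (the sizes below are made up purely for illustration), the broadcasted intermediate alone can dominate memory:

import torch

batch, A, B = 512, 1024, 256
x = torch.randn(batch, 1, B)   # (batch, 1, B)
y = torch.randn(1, A, B)       # (1, A, B)
# x - y materializes a (batch, A, B) float32 tensor before the mean collapses it away:
print(batch * A * B * 4 / 2**20, "MiB")   # 512.0 MiB for the intermediate alone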

Possible Solutions:


Solution

  • To save memory I recommend using torch.einsum: We can make use of the trigonometric identity

    cos(x-y) = cos(x)*cos(y) + sin(x)*sin(y)
    

    In this case we can apply einsum: averaging over the shared dimension B is just a sum of products over that dimension, which is exactly what the einsum computes; the + between the two products is applied afterwards, and dividing by B turns the sum into the mean. In short:

    xs, ys = torch.sin(x), torch.sin(y)
    xc, yc = torch.cos(x), torch.cos(y)
    # einsum contracts the sin/cos products over the shared dimension B; + adds the two products, / B turns the sum into the mean:
    out = (torch.einsum('i k, j k -> i j', xs, ys) + torch.einsum('i k, j k -> i j', xc, yc)) / y.shape[1]
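
    Since both einsums only contract over the shared last dimension, they are just matrix multiplications; an equivalent spelling of the same computation (a sketch, with x of shape (batch, B) and y of shape (A, B)) would be

    # same result as the einsum version: contract over B with plain matmuls, then average
    out = (xs @ ys.T + xc @ yc.T) / y.shape[1]   # (batch, A)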
    

    Since measuring memory consumption directly is a little tedious, I resorted to measuring time as a proxy. Here you can see your original method and my proposal for various input sizes. (The script for generating these plots is attached below.)

    [Plot: log-log runtime vs. input size n for the naive and einsum methods, with x^1 and x^2 reference lines.]

    import matplotlib.pyplot as plt
    import torch
    import time
    
    def main():
        ns = torch.logspace(1, 3.2, 20).to(torch.long)
        tns = []; tes = []
        for n in ns:
            tn, te = compare(n)
            tns.append(tn); tes.append(te)
        plt.loglog(ns, tns, ':.')
        plt.loglog(ns, tes, '.-')
        plt.loglog(ns, 1e-6*ns**1, ':')
        plt.loglog(ns, 1e-6*ns**2, ':')
        plt.legend(['naive', 'einsum', 'x^1', 'x^2'])
        plt.show()
    
    def compare(n):
        batch = a = b = n
        x = torch.rand((batch, b))  # (batch, B); random data so the correctness check below is meaningful
        y = torch.rand((a, b))      # (A, B)
        t = time.perf_counter(); ra = af(x.unsqueeze(1), y.unsqueeze(0)); print('naive method', tn := time.perf_counter() - t)
        t = time.perf_counter(); rb = bf(x, y); print('einsum method', te := time.perf_counter() - t)
        print((ra - rb).abs().max())  # verify both methods agree (up to floating-point error)
        return tn, te
    
    def af(x, y):
        return torch.cos(x - y).mean(dim=2) 
    
    def bf(x, y):
        xs, ys = torch.sin(x), torch.sin(y)
        xc, yc = torch.cos(x), torch.cos(y)
        return (torch.einsum('i k, j k -> i j', xs, ys) + torch.einsum('i k, j k -> i j', xc, yc)) / y.shape[1]
    
    main()
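
    If you do want to measure memory directly rather than use time as a proxy, and the tensors live on a CUDA device, one option is to read PyTorch's peak-allocation counters around each call. This is just a sketch building on the af/bf definitions above, not part of the benchmark script:

    def peak_cuda_memory(fn, *args):
        # run fn(*args) and report the peak CUDA memory allocated during the call
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        out = fn(*args)
        torch.cuda.synchronize()
        return out, torch.cuda.max_memory_allocated()

    # e.g. with hypothetical sizes, comparing the naive and einsum versions:
    # x = torch.rand(512, 256, device='cuda'); y = torch.rand(1024, 256, device='cuda')
    # _, naive_peak  = peak_cuda_memory(af, x.unsqueeze(1), y.unsqueeze(0))
    # _, einsum_peak = peak_cuda_memory(bf, x, y)
    # print(naive_peak, einsum_peak)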