I have two 2-d tensors which align via broadcasting, so if I add/subtract them I incur a huge 3-d intermediate tensor. I don't actually need that intermediate, since I immediately take a mean over one dimension afterwards. In this demo I unsqueeze the tensors to show how they align, but they are 2-d otherwise.
x = torch.tensor(...)  # (batch, 1, B)
y = torch.tensor(...)  # (1, A, B)
out = torch.cos(x - y).mean(dim=2)  # (batch, A)
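To make the cost concrete: the broadcasted difference materializes all batch*A*B elements before the mean throws that dimension away. A quick back-of-the-envelope (the sizes here are made up for illustration):

batch, A, B = 1024, 1024, 1024       # hypothetical sizes
n_bytes = batch * A * B * 4          # float32 intermediate of shape (batch, A, B)
print(f"{n_bytes / 2**30:.1f} GiB")  # 4.0 GiB allocated just for cos(x - y)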
Possible Solutions:
An algebraic simplification, but for the life of me I haven't found one yet.
Some PyTorch primitive that'll help? This is cosine similarity, but a bit different from torch.cosine_similarity: I'm applying it to the .angle()s of complex numbers (see the sketch after this list).
Custom C/CPython code that loops efficiently.
Other?
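For context, here is a minimal sketch of the setting (the shapes and data are invented for illustration; only the .angle() usage is the real constraint):

import torch

zx = torch.randn(8, 16, dtype=torch.cfloat)  # hypothetical complex data, (batch, B)
zy = torch.randn(4, 16, dtype=torch.cfloat)  # hypothetical complex data, (A, B)
x, y = zx.angle(), zy.angle()                # phases in (-pi, pi]

# the broadcasted version materializes a (batch, A, B) intermediate:
out = torch.cos(x.unsqueeze(1) - y.unsqueeze(0)).mean(dim=2)  # (batch, A)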
To save memory I recommend using torch.einsum:
We can make use of the trigonometric identity

cos(x - y) = cos(x)*cos(y) + sin(x)*sin(y)

Averaging over the last dimension then splits into two sums:

mean_k cos(x[i,k] - y[j,k]) = (sum_k cos(x[i,k])*cos(y[j,k]) + sum_k sin(x[i,k])*sin(y[j,k])) / B

Each of those sums over k is exactly what einsum computes, so einsum's usual summing performs the averaging (up to the division by B), and the + between the two products is an ordinary tensor addition applied afterwards. In short:
xs, ys = torch.sin(x), torch.sin(y)
xc, yc = torch.cos(x), torch.cos(y)
# use einsum for the sin/cos products and the averaging sum, use + for the sum of products:
out = (torch.einsum('i k, j k -> i j', xs, ys) + torch.einsum('i k, j k -> i j', xc, yc)) / y.shape[1]  # (batch, A)
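Since 'i k, j k -> i j' contracts over the shared k index, each einsum here is just a matrix product, so an equivalent way to write the same computation (same result, same memory behaviour) is:

out = (torch.sin(x) @ torch.sin(y).T + torch.cos(x) @ torch.cos(y).T) / y.shape[1]  # (batch, A)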
While measuring the memory consumption is a little tedious, I resorted to measuring time as a proxy. Here you can see your original method and my proposal for various input sizes. (The script for generating these plots is attached below.)
import time

import matplotlib.pyplot as plt
import torch


def main():
    ns = torch.logspace(1, 3.2, 20).to(torch.long)
    tns, tes = [], []
    for n in ns:
        tn, te = compare(n)
        tns.append(tn); tes.append(te)
    plt.loglog(ns, tns, ':.'); plt.loglog(ns, tes, '.-')
    plt.loglog(ns, 1e-6 * ns**1, ':'); plt.loglog(ns, 1e-6 * ns**2, ':')
    plt.legend(['naive', 'einsum', 'x^1', 'x^2'])
    plt.show()


def compare(n):
    batch = a = b = n
    # random inputs so that the equality check below is meaningful
    x = torch.randn((batch, b))  # (batch, B)
    y = torch.randn((a, b))      # (A, B)
    t = time.perf_counter(); ra = af(x.unsqueeze(1), y.unsqueeze(0)); print('naive method', tn := time.perf_counter() - t)
    t = time.perf_counter(); rb = bf(x, y); print('einsum method', te := time.perf_counter() - t)
    print((ra - rb).abs().max())  # verify both methods give the same results
    return tn, te


def af(x, y):
    # naive method: materializes the full (batch, A, B) intermediate
    return torch.cos(x - y).mean(dim=2)


def bf(x, y):
    # einsum method: never allocates anything larger than 2-d
    xs, ys = torch.sin(x), torch.sin(y)
    xc, yc = torch.cos(x), torch.cos(y)
    return (torch.einsum('i k, j k -> i j', xs, ys) + torch.einsum('i k, j k -> i j', xc, yc)) / y.shape[1]


main()
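If you do want to measure memory directly instead of using time as a proxy: on a GPU, PyTorch's CUDA allocator tracks peak usage, so a sketch like the following could compare the two methods head-on (this assumes a CUDA device is available; peak_mem is a helper I made up for illustration):

def peak_mem(fn, *args):
    # peak GPU memory allocated while running fn, in MiB
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

x = torch.randn(512, 512, device='cuda')  # (batch, B)
y = torch.randn(512, 512, device='cuda')  # (A, B)
print('naive :', peak_mem(lambda u, v: torch.cos(u.unsqueeze(1) - v.unsqueeze(0)).mean(dim=2), x, y))
print('einsum:', peak_mem(bf, x, y))  # bf as defined in the script above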