I would like to speed up the following task, where at every t I compute some dot products between given arrays to fill in one column of a matrix:
```python
import numpy as np

N = 100
T = 120
K = 5
L = 6
we = np.random.normal(size=(L, L, T))
ga = np.random.normal(size=(L, K))
qu = np.random.normal(size=(L, T))
f1 = np.full((K, T), np.nan)
for t in range(T):
    m1 = ga.T.dot(we[:, :, t]).dot(ga)  # (K, K) system matrix
    m2 = ga.T.dot(qu[:, t])             # (K,) right-hand side
    f1[:, t] = np.linalg.solve(m1, m2)
```
To speed up the dot products leading to m1 and m2, I used `np.matmul` (and `np.dot`) to stack everything into a 120x5x5 array and a 5x120 matrix, respectively:
```python
n1 = np.matmul(ga.T, np.matmul(we.transpose(2, 0, 1), ga))
n2 = np.dot(ga.T, qu)
```
This should perform the same calculations, just cast into a higher-order array. Now comes `np.linalg.solve`, which I would ideally like to perform without a loop, exploiting the same logic that `np.matmul` uses for the dot product. However:
```python
np.linalg.solve(n1, n2)
```
throws a shape-mismatch error. Is there a way to do what I need in one call, using the `matmul` logic for `solve`? What about other functions like `np.linalg.svd`?
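By convention, matrix multiplication in NumPy (`np.matmul`) treats the last two axes of each operand as the matrices, summing over the last axis of the first operand and the second-to-last axis of the second; any leading axes are treated as a stack (batch) of matrices and are broadcast. So it is good practice to make the "T" axis (the "batch" axis, in ML terms) the first axis. The same is true for `np.linalg.solve`. Now look at the shapes you obtained: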
```python
>>> n1 = np.matmul(ga.T, np.matmul(we.transpose(2, 0, 1), ga))
>>> print(n1.shape)
(120, 5, 5)
```
And:
```python
>>> n2 = np.dot(ga.T, qu)
>>> print(n2.shape)
(5, 120)
```
You can see that the batch axis is in a different place: it comes first in `n1` but last in `n2`. So you are just missing an additional transpose:
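```python
import numpy as np

N = 100
T = 120
K = 5
L = 6
we = np.random.normal(size=(L, L, T))
ga = np.random.normal(size=(L, K))
qu = np.random.normal(size=(L, T))

# loop version
f1 = np.full((K, T), np.nan)
for t in range(T):
    m1 = ga.T.dot(we[:, :, t]).dot(ga)
    m2 = ga.T.dot(qu[:, t])
    f1[:, t] = np.linalg.solve(m1, m2)

# vectorized version: batch axis T first
m1 = np.matmul(ga.T, np.matmul(we.transpose(2, 0, 1), ga))  # (T, K, K)
m2 = np.matmul(ga.T, qu)                                    # (K, T)
# m2.T is (T, K); the trailing None axis marks each row as a column vector.
# NumPy 2.x needs this (it no longer infers a stack of vectors), 1.x accepts it too.
f2 = np.linalg.solve(m1, m2.T[..., None])[..., 0].T        # back to (K, T)
print(np.allclose(f2, f1))
```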
This returns `True`.
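The same batch-first convention should carry over to the other `np.linalg` functions you asked about: `np.linalg.svd` (like `np.linalg.inv` and `np.linalg.det`) accepts arrays of shape `(..., M, N)` and works on the last two axes. The following is a minimal sketch, not a benchmark: it assumes random data laid out batch-first from the start (the names `we_b`, `qu_b`, `A` and `b` are only illustrative) and checks the stacked `matmul`, `solve` and `svd` results against per-`t` loops.

```python
import numpy as np

T, K, L = 120, 5, 6
rng = np.random.default_rng(0)

# Illustrative batch-first data: T stacked L x L matrices, T vectors of length L.
we_b = rng.normal(size=(T, L, L))
ga = rng.normal(size=(L, K))
qu_b = rng.normal(size=(T, L))

# matmul broadcasts the 2-D ga against the leading T axis:
# (K, L) @ (T, L, L) @ (L, K) -> (T, K, K)
A = ga.T @ we_b @ ga
# (T, L) @ (L, K) -> (T, K): one right-hand side per t
b = qu_b @ ga

# solve loops over the leading axis; the trailing length-1 axis marks
# each b[t] as a single column vector.
x = np.linalg.solve(A, b[..., None])[..., 0]  # (T, K)

# svd returns stacked factors: U (T, K, K), S (T, K), Vh (T, K, K)
U, S, Vh = np.linalg.svd(A)

# Compare against explicit per-t loops.
x_loop = np.stack([np.linalg.solve(A[t], b[t]) for t in range(T)])
S_loop = np.stack([np.linalg.svd(A[t], compute_uv=False) for t in range(T)])

print(x.shape, S.shape)                                  # (120, 5) (120, 5)
print(np.allclose(x, x_loop), np.allclose(S, S_loop))    # True True
# Rebuild each A[t]: multiplying U by S[:, None, :] scales the columns of U.
print(np.allclose((U * S[:, None, :]) @ Vh, A))          # True
```

Storing `we` as `(T, L, L)` up front, as in this sketch, makes the `transpose` calls unnecessary.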
I hope this helps!