Is there any way of optimizing the performance of this function?
```python
def func(X):
    n, p = X.shape
    R = np.eye(p)
    delta = 0.0
    for i in range(100):
        delta_old = delta
        Y = X @ R
        alpha = 1. / n
        Y2 = Y**2
        Y3 = Y2 * Y
        W = np.sum(Y2, axis=0)
        transformed = X.T @ (Y3 - (alpha * Y * W))
        U, svals, VT = np.linalg.svd(transformed, full_matrices=False)
        R = U @ VT
        delta = np.sum(svals)  # is used as a stopping criterion
    return R
```
Naively, I thought using numba would help because of the loop (the actual number of iterations is higher):
```python
from numba import jit

@jit(nopython=True, parallel=True)
def func_numba(X):
    n, p = X.shape
    R = np.eye(p)
    delta = 0.0
    for i in range(100):
        delta_old = delta
        Y = X @ R
        alpha = 1. / n
        Y2 = Y**2
        Y3 = Y2 * Y
        W = np.sum(Y2, axis=0)
        transformed = X.T @ (Y3 - (alpha * Y * W))
        U, svals, VT = np.linalg.svd(transformed, full_matrices=False)
        R = U @ VT
        delta = np.sum(svals)  # is used as a stopping criterion
    return R
```
but to my surprise, the numba-compiled function is actually slower. Why is numba not more effective in this case? Is there another option for me (preferably using numpy)?
Note: You can assume `X` to be a "tall-and-skinny" matrix.
```python
import numpy as np

size = (10_000, 15)
X = np.random.normal(size=size)

%timeit func(X)        # 1.28 s
%timeit func_numba(X)  # 2.05 s
```
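A plausible explanation (a sketch, not verified on the asker's machine): the heavy lifting here is matrix multiplication and the SVD, which NumPy already hands off to BLAS/LAPACK, and numba's nopython mode dispatches those same calls to the same libraries, so there is little left to compile away. Timing the individual operations for the given shapes makes this visible (`time_op` is a helper name I made up; `time.perf_counter` is used instead of `%timeit` so the snippet runs outside IPython):

```python
import time
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(10_000, 15))
R = np.eye(15)
Y = X @ R

def time_op(label, op, reps=100):
    # Average wall-clock time of `op` over `reps` calls, in milliseconds.
    t0 = time.perf_counter()
    for _ in range(reps):
        op()
    print(f"{label}: {(time.perf_counter() - t0) / reps * 1e3:.3f} ms")

time_op("X @ R   (BLAS matmul)", lambda: X @ R)
time_op("X.T @ Y (BLAS matmul)", lambda: X.T @ Y)
time_op("SVD of 15x15 (LAPACK)", lambda: np.linalg.svd(X.T @ Y, full_matrices=False))
time_op("elementwise Y**2", lambda: Y**2)
```

If the matmuls dominate, numba cannot beat them; it can only add compilation and dispatch overhead around them.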
I tried to rewrite some things to speed it up. The only changes I made (apart from some formatting, maybe) were removing the computation of the unused `delta`, pulling the transposition of `X` out of the loop, and factoring the multiplication with `Y` out of the computation of `transformed`, which saves one multiplication.
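The factoring step rests on the elementwise identity `Y**3 - alpha * Y * W == Y * (Y**2 - alpha * W)`, with the row vector `W` broadcast across the rows of `Y`. A quick sanity check on random data (the shapes here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 100, 5
Y = rng.normal(size=(n, p))
alpha = 1.0 / n
Y2 = Y**2
W = np.sum(Y2, axis=0)  # shape (p,), broadcast over rows below

original = Y**3 - alpha * Y * W    # three elementwise multiplications
factored = Y * (Y2 - alpha * W)    # two elementwise multiplications
assert np.allclose(original, factored)
```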
```python
def func2(X):
    n, p = X.shape
    R = np.eye(p)
    alpha = 1. / n
    XT = X.T
    for i in range(100):
        Y = X @ R
        Y2 = Y**2
        W = np.sum(Y2, axis=0)
        transformed = XT @ (Y * (Y2 - (alpha * W)))
        U, svals, VT = np.linalg.svd(transformed, full_matrices=False)
        R = U @ VT
    return R
```
When I compare this `func2` to the original `func` as follows
```python
X = np.random.normal(size=(10_000, 15))
assert np.allclose(func(X), func2(X))

%timeit func(X)
%timeit func2(X)
```
I get a speedup of more than 1.5x (it's not always as pronounced as that, however):

```
197 ms ± 44.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
127 ms ± 21.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
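Since `delta` was originally meant as a stopping criterion, a further option is to restore it and break out of the loop once the singular-value sum stops changing, so that fewer than 100 iterations may run. A sketch (not benchmarked; `func3`, `max_iter`, and `tol` are names I chose):

```python
import numpy as np

def func3(X, max_iter=100, tol=1e-8):
    # Same as func2, but stops early once the change in the
    # singular-value sum (delta) falls below a relative tolerance.
    n, p = X.shape
    R = np.eye(p)
    alpha = 1. / n
    XT = X.T
    delta = 0.0
    for i in range(max_iter):
        delta_old = delta
        Y = X @ R
        Y2 = Y**2
        W = np.sum(Y2, axis=0)
        transformed = XT @ (Y * (Y2 - (alpha * W)))
        U, svals, VT = np.linalg.svd(transformed, full_matrices=False)
        R = U @ VT
        delta = np.sum(svals)
        if abs(delta - delta_old) < tol * max(delta, 1.0):
            break
    return R
```

How much this helps depends entirely on how quickly the iteration converges for your data; if it routinely needs close to `max_iter` iterations, the check buys nothing.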