I need to compute the log-determinant of the Gram matrix of a matrix J,
and I am wondering whether there is an efficient and numerically stable way to do this in NumPy/SciPy.
import numpy as np
m, n = 100, 150
J = np.random.randn(m, n)
np.log(np.linalg.det(J.dot(J.T)))
Is there a LAPACK routine or a mathematical trick I could use to speed this up and make it more stable?
For better numerical stability, I would suggest using np.linalg.slogdet, which returns the sign and the log of the absolute value of the determinant directly; the log-determinant is what you are after anyway. There may also be a very small gain if you use np.inner(J, J) instead of J.dot(J.T). For a real speedup, I would recommend using jax.numpy.
import numpy as np
import jax
import jax.numpy as jnp
m, n = 100, 150
J = np.random.randn(m, n)
def a(J):
    return np.log(np.linalg.det(J.dot(J.T)))
def b(J):
    return np.linalg.slogdet(np.inner(J, J))[1]
def c(J):
    return jnp.linalg.slogdet(jnp.inner(J, J))[1]
# jit + compile
d = jax.jit(c)
d(J)
# check correctness
print(np.allclose(a(J), b(J))) # True
print(np.allclose(a(J), c(J))) # True
print(np.allclose(a(J), d(J))) # True
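To illustrate why slogdet is the more stable choice, here is a minimal sketch. It assumes a larger matrix than in the question (400 x 600, chosen only for illustration) so that the determinant itself no longer fits in a float64, while its logarithm is still an ordinary number.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sizes are an assumption for this demo: larger than in the question,
# so that det(J J^T) exceeds the float64 range.
m, n = 400, 600
J = rng.standard_normal((m, n))
G = J @ J.T  # Gram matrix, symmetric positive definite for full-rank J

# Naive route: the determinant overflows to inf before the log is taken.
with np.errstate(over="ignore"):
    naive = np.log(np.linalg.det(G))

# slogdet works with the logarithm throughout and stays finite.
sign, logdet = np.linalg.slogdet(G)

print(naive)   # expected to be inf for matrices of this size
print(sign, logdet)
```

For the 100 x 150 case in the question both routes agree, as the checks above show; the difference only appears once the determinant grows (or shrinks) beyond what a float64 can hold.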
Checking run times, on Google Colab:
%timeit -n 1000 -r 10 a(J)
# 240 µs ± 16.2 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
%timeit -n 1000 -r 10 b(J)
# 227 µs ± 10.2 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
J_dev = jax.device_put(J)
%timeit -n 1000 -r 10 c(J_dev).block_until_ready()
# 112 µs ± 4.46 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
%timeit -n 1000 -r 10 d(J_dev).block_until_ready()
# 96.2 µs ± 4.23 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
So roughly a 2x speedup is possible this way.
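As for a math trick, one option that is not benchmarked here is to skip the Gram matrix entirely. Assuming J has full row rank with m <= n, det(J J^T) is the product of the squared singular values of J, so the log-determinant is twice the sum of their logs. A minimal sketch, with slogdet used only as a reference:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 100, 150
J = rng.standard_normal((m, n))

# Singular values of J only; U and V are not needed.
s = np.linalg.svd(J, compute_uv=False)

# log det(J J^T) = sum(log(s_i^2)) = 2 * sum(log(s_i)),
# valid when all m singular values are nonzero.
logdet_svd = 2.0 * np.sum(np.log(s))

# Reference value via the Gram matrix.
sign, logdet_ref = np.linalg.slogdet(J @ J.T)

print(np.isclose(logdet_svd, logdet_ref))
```

Forming J J^T squares the condition number of J, so working from the singular values of J can help when J is badly conditioned. Whether it is faster than slogdet on the Gram matrix depends on the shapes involved and would need to be timed.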