I have two approaches to exponentiating a matrix with jnp = jax.numpy. A straightforward one:
jnp.exp(-X/reg)
And one with some additional steps:
def exp_reg(X, reg):
    K = jnp.empty_like(X)
    K = jnp.divide(X, -reg)
    return jnp.exp(K)
However, when I tested them:
%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()
The second approach turned out to outperform the first, despite superficially having some additional overhead. I ran %timeit with a matrix of size 2000 x 2000:
7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Why might this be the case?
The difference here is the order of operations.
In jnp.exp(-X/reg), you are negating every entry of X, and then dividing each entry of the result by reg. That's two passes over the array X.
In exp_reg, you are negating reg (which presumably is a scalar value?) and then dividing X by the result. That's one pass over X.
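You can confirm this by inspecting the sequence of operations JAX traces for each version. Here is a minimal sketch using jax.make_jaxpr, with a small illustrative X; the printed jaxprs are paraphrased in the comments:

import jax
import jax.numpy as jnp

X = jnp.ones((3, 3))
reg = 2.0

# First version: neg runs elementwise over all of X, then div, then exp.
print(jax.make_jaxpr(lambda X, reg: jnp.exp(-X / reg))(X, reg))
# roughly: c:f32[3,3] = neg a; d:f32[3,3] = div c b; e:f32[3,3] = exp d

# Second version: neg runs once on the scalar reg, then a single div over X.
print(jax.make_jaxpr(lambda X, reg: jnp.exp(jnp.divide(X, -reg)))(X, reg))
# roughly: c:f32[] = neg b; d:f32[3,3] = div a c; e:f32[3,3] = exp d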
If X is large, I would expect the first approach to be slightly slower than the second, due to the multiple passes over X.
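If you want the one-pass behavior without defining a helper function, you can also fold the negation into the scalar inline (equivalent as long as reg is a scalar):

jnp.exp(X / -reg)  # negates the scalar once, then a single divide pass over X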
Fortunately, since you're using JAX, you can jit-compile your code, in which case XLA can generally optimize across equivalent orderings of operations like these. Indeed, for your two functions, compilation eliminates the discrepancy:
from jax import jit
import jax.numpy as jnp
import numpy as np

def exp_reg1(X, reg):
    return jnp.exp(-X / reg)

def exp_reg2(X, reg):
    K = jnp.divide(X, -reg)
    return jnp.exp(K)

X = jnp.array(np.random.rand(1000, 1000))
reg = 2.0

%timeit exp_reg1(X, reg).block_until_ready()
# 100 loops, best of 3: 3.17 ms per loop
%timeit exp_reg2(X, reg).block_until_ready()
# 100 loops, best of 3: 2.2 ms per loop

# Wrap once and trigger compilation
exp_reg1_jit = jit(exp_reg1)
exp_reg2_jit = jit(exp_reg2)
exp_reg1_jit(X, reg).block_until_ready()
exp_reg2_jit(X, reg).block_until_ready()

%timeit exp_reg1_jit(X, reg).block_until_ready()
# 1000 loops, best of 3: 1.92 ms per loop
%timeit exp_reg2_jit(X, reg).block_until_ready()
# 100 loops, best of 3: 1.84 ms per loop
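If you want to check what XLA actually does here, recent JAX versions expose the compiled HLO through the ahead-of-time lowering API; under jit, the negate/divide/exp sequence is typically fused into a single elementwise kernel, which is why the two versions converge:

# Inspect the optimized HLO for the first version (recent JAX API).
print(jit(exp_reg1).lower(X, reg).compile().as_text())
# Expect a single fused elementwise computation combining negate, divide,
# and exp, rather than separate passes over the array.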
(Side note: there's no reason to pre-allocate an empty array K before assigning the result of an operation to a variable of the same name; the pre-allocated array is simply discarded.)
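To spell that out with a minimal sketch: JAX arrays are immutable, and plain assignment only rebinds the Python name, so the empty array is never used:

import jax.numpy as jnp

X = jnp.arange(4.0)
K = jnp.empty_like(X)    # allocates an array that is never read
K = jnp.divide(X, -2.0)  # rebinds the name K; the empty array is discarded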