python jax

Why is JAX's jit compilation slower on the second run in my example?


I am new to using JAX, and I’m still getting familiar with how it works. From what I understand, when using Just-In-Time (JIT) compilation (jax.jit), the first execution of a function might be slower due to the compilation overhead, but subsequent executions should be faster. However, I am seeing the opposite behavior.

In the following code snippet:

from icecream import ic
import jax
from time import time
import numpy as np


@jax.jit
def my_function(x, y):
    return x @ y


vectorized_function = jax.vmap(my_function, in_axes=(0, None))

shape = (1_000_000, 1_000)

x = np.ones(shape)
y = np.ones(shape[1])

start = time()
vectorized_function(x, y)
t_1 = time() - start

start = time()
vectorized_function(x, y)
t_2 = time() - start

print(f'{t_1 = }\n{t_2 = }')

I get the following results:

t_1 = 13.106784582138062
t_2 = 15.664098024368286

As you can see, the second run (t_2) is actually slower than the first one (t_1), which seems counterintuitive to me. I expected the second run to be faster due to JAX’s JIT caching.

Has anyone encountered a similar situation, or does anyone have insight into why this might be happening?

PS: I know I could have computed x @ y directly without invoking vmap, but this is a simple example just to test its behaviour. My actual code is more complex, and the difference in runtime is even bigger (around 8x slower). I hope this simple example behaves similarly.


Solution

  • For general tips on running JAX microbenchmarks effectively, see FAQ: Benchmarking JAX code.

    I cannot reproduce the timings from your snippet, but in your more complicated case I suspect you are being misled by JAX's asynchronous dispatch: a call returns control to Python before the computation has finished, so the timing method you're using does not reflect the time taken by the underlying computation. To address this, call block_until_ready() on the result (or wrap it in jax.block_until_ready) before stopping the timer:

    start = time()
    vectorized_function(x, y).block_until_ready()
    t_1 = time() - start
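
Putting this together, a fuller benchmark might look like the sketch below. It makes a few assumptions that are not in the original snippet: the shape is reduced so it runs quickly, the inputs are created with jax.numpy so that host-to-device transfer of NumPy arrays is not included in the timing, and timed_call is a hypothetical helper introduced only for illustration.

```python
from time import perf_counter

import jax
import jax.numpy as jnp


@jax.jit
def my_function(x, y):
    return x @ y


vectorized_function = jax.vmap(my_function, in_axes=(0, None))

# Assumption: a smaller shape than in the question, so the sketch runs quickly.
shape = (10_000, 1_000)

# Assumption: build the inputs as JAX arrays so that converting and
# transferring NumPy arrays is not part of what gets timed.
x = jnp.ones(shape)
y = jnp.ones(shape[1])


def timed_call(fn, *args):
    """Hypothetical helper: time one call, waiting until the result is ready."""
    start = perf_counter()
    result = jax.block_until_ready(fn(*args))
    return result, perf_counter() - start


# The first call includes tracing and compilation overhead.
out, t_1 = timed_call(vectorized_function, x, y)
# The second call should be able to reuse the cached compilation.
out, t_2 = timed_call(vectorized_function, x, y)

print(f"{t_1 = }\n{t_2 = }")
```

With the result blocked on and the inputs already on the device, t_2 should mostly reflect the computation itself, though the exact numbers will depend on your hardware and JAX version.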