# Colab, Jax, and GPU: why does cell execution take 60 seconds when %%timeit says it only takes 70 ms?

As the basis for a project on fractals, I'm trying to use GPU computation on Google Colab using the Jax library.

I'm using the "Mandelbrot on all accelerators" notebook as a model, and I'm encountering a problem.

When I use the `%%timeit` magic to measure how long it takes to calculate my GPU function (the same function as in the model notebook), the times are entirely reasonable and in line with expected results -- **70 to 80 ms**.

But actually *running* `%%timeit` takes something like **a full minute**. (By default, it runs the function 7 times in a row and reports the average -- but even that should take less than a second.)

Similarly, when I run the function in a cell and output the results (a 6 megapixel image), it takes around 60 seconds for the cell to finish -- to execute a function that supposedly only takes 70-80 ms.

It seems like something is producing a massive amount of overhead that also scales with the amount of computation: e.g., when the function contains 1,000 iterative calculations, `%%timeit` says it takes 71 ms while in reality it takes 60 seconds, but with just 20 iterations `%%timeit` says it takes 10 ms while in reality it takes about 10 seconds.

I am pasting the code below, but here is a link to the Colab notebook itself -- anyone can make a copy, connect to a "T4 GPU" instance, and run it themselves to see.

```
import math
import numpy as np
import matplotlib.pyplot as plt
import jax

assert len(jax.devices("gpu")) == 1

def run_jax_kernel(c, fractal):
    z = c
    for i in range(1000):
        z = z**2 + c
        diverged = jax.numpy.absolute(z) > 2
        diverging_now = diverged & (fractal == 1000)
        fractal = jax.numpy.where(diverging_now, i, fractal)
    return fractal

run_jax_gpu_kernel = jax.jit(run_jax_kernel, backend="gpu")

def run_jax_gpu(height, width):
    mx = -0.69291874321833995150613818345974774914923989808007473759199
    my = 0.36963080032727980808623018005116209090839988898368679237704
    zw = 4 / 1e3
    y, x = jax.numpy.ogrid[(my-zw/2):(my+zw/2):height*1j, (mx-zw/2):(mx+zw/2):width*1j]
    c = x + y*1j
    fractal = jax.numpy.full(c.shape, 1000, dtype=np.int32)
    return np.asarray(run_jax_gpu_kernel(c, fractal).block_until_ready())
```

Takes about a minute to produce an image:

```
fig, ax = plt.subplots(1, 1, figsize=(15, 10))
ax.imshow(run_jax_gpu(2000, 3000));
```

Takes about a minute to report that the function only takes 70-80 ms to execute:

```
%%timeit -o
run_jax_gpu(2000, 3000)
```

## Solution

The first thing to realize is that `%timeit` will execute your code multiple times, and then return an average of the times for each run. The number of times it will execute is determined dynamically by the time of the first run.
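To see how that averaging hides a slow first call, here is a minimal stand-alone sketch in plain Python, simulating the one-time cost with a sleep rather than with actual compilation:

```python
import time
import timeit

state = {"first": True}

def f():
    # simulate a function with a one-time setup cost (like JIT compilation)
    if state["first"]:
        state["first"] = False
        time.sleep(0.5)

# Averaging over 1000 calls amortizes the 500 ms first call
# down to roughly 0.5 ms per call:
total = timeit.timeit(f, number=1000)
print(f"mean per call: {total / 1000 * 1000:.3f} ms")
```

The reported mean per call is three orders of magnitude smaller than what the first call actually cost, which is exactly the discrepancy in the question.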

The second thing to realize is that JAX code is just-in-time (JIT) compiled, meaning that on the first execution of any particular function, you will incur a one-time compilation cost. Many things affect compilation cost, but functions that use large Python `for` loops (say, 1000 or more repetitions) tend to compile very slowly, because JAX unrolls those loops before passing the operations to XLA for compilation, and XLA compilation scales approximately quadratically with the number of unrolled operations (there is some discussion of this at JAX Sharp Bits: Control Flow).
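You can see the unrolling directly by inspecting the traced program (the jaxpr): a sketch with a toy function of my own, where the operation count grows linearly with the loop length:

```python
import jax

def repeated_double(x, n):
    # a Python for loop is unrolled during tracing, so the traced
    # program contains one multiply per iteration
    for _ in range(n):
        x = x * 2.0
    return x

small = jax.make_jaxpr(lambda x: repeated_double(x, 5))(1.0)
large = jax.make_jaxpr(lambda x: repeated_double(x, 500))(1.0)
print(len(small.jaxpr.eqns), len(large.jaxpr.eqns))
```

XLA then has to compile every one of those unrolled operations, which is where the quadratic scaling bites for 1000-iteration loops.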

Put these together, and you'll see why you're observing the timings that you are: under `%timeit`, your first run results in a very long compilation, and subsequent runs are very fast. The resulting average time is printed, and is very short compared to the first run, and to the overall time.

When you run your code a single time to plot the results, you are mainly seeing the compilation time. Because it is not amortized away by multiple calls to your function, that compilation time is long.
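You can confirm this for any jitted function by timing the first call against a later one; a minimal sketch with a toy function (not your Mandelbrot kernel):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # a tiny jitted function; the first call triggers tracing + XLA compilation
    return jnp.sum(x ** 2)

x = jnp.arange(1000.0)

t0 = time.perf_counter()
f(x).block_until_ready()   # first call: compile + execute
first = time.perf_counter() - t0

t0 = time.perf_counter()
f(x).block_until_ready()   # second call: execute only, served from the compilation cache
second = time.perf_counter() - t0

print(f"first call: {first*1000:.1f} ms, second call: {second*1000:.3f} ms")
```

Note the `block_until_ready()` calls: JAX dispatches work asynchronously, so without them you would be timing only the dispatch, not the computation.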

The solution would be to avoid writing Python `for` loops in your function in order to avoid the long compilation time: one possibility would be to use `lax.fori_loop`, which allows you to write iterative computations without the huge compilation-time penalty, though it will incur a runtime penalty on GPU compared to the `for` loop solution, because the operations are executed sequentially rather than being parallelized by the compiler. In your case it might look like this:

```
def run_jax_kernel(c, fractal):
    z = c
    def body_fun(i, carry):
        z, fractal = carry
        z = z**2 + c
        diverged = jax.numpy.absolute(z) > 2
        diverging_now = diverged & (fractal == 1000)
        fractal = jax.numpy.where(diverging_now, i, fractal)
        return (z, fractal)
    z, fractal = jax.lax.fori_loop(0, 1000, body_fun, (z, fractal))
    return fractal
```
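For completeness, here is a self-contained end-to-end sketch of that version on a small grid (the 50×50 grid and the viewport bounds are arbitrary choices for the example, not your zoomed-in coordinates):

```python
import numpy as np
import jax
import jax.numpy as jnp

def run_jax_kernel(c, fractal):
    def body_fun(i, carry):
        z, fractal = carry
        z = z**2 + c
        diverged = jnp.absolute(z) > 2
        diverging_now = diverged & (fractal == 1000)
        fractal = jnp.where(diverging_now, i, fractal)
        return (z, fractal)
    # the loop is a single traced primitive, so compilation stays fast
    _, fractal = jax.lax.fori_loop(0, 1000, body_fun, (c, fractal))
    return fractal

run_kernel = jax.jit(run_jax_kernel)

# small grid over the full Mandelbrot set as a smoke test
y, x = jnp.ogrid[-1.2:1.2:50j, -2.0:0.8:50j]
c = x + y * 1j
fractal = jnp.full(c.shape, 1000, dtype=np.int32)
out = np.asarray(run_kernel(c, fractal).block_until_ready())
print(out.shape)
```

The first call here should compile in well under a second, in contrast to the minute-long compile of the unrolled version.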
