Tags: python, jupyter-notebook, google-colaboratory, jax

Colab, Jax, and GPU: why does cell execution take 60 seconds when %%timeit says it only takes 70 ms?


As the basis for a project on fractals, I'm trying to run GPU computations on Google Colab with the JAX library.

I'm using the "Mandelbrot on all accelerators" notebook as a model, and I'm encountering a problem.

When I use the %%timeit command to measure how long it takes to calculate my GPU function (same as in the model notebook), the times are entirely reasonable, and in line with expected results -- 70 to 80 ms.

But actually running %%timeit takes something like a full minute. (By default, it runs the function 7 times in a row and reports the average -- but even that should take less than a second.)

Similarly, when I run the function in a cell and output the results (a 6 megapixel image), it takes around 60 seconds for the cell to finish -- to execute a function that supposedly only takes 70-80 ms.

It seems like something is producing a massive amount of overhead that also scales with the amount of computation. For example, when the function contains 1,000 iterations, %%timeit says it takes 71 ms while the cell actually takes 60 seconds; with just 20 iterations, %%timeit says it takes 10 ms while the cell actually takes about 10 seconds.

I am pasting the code below, but here is a link to the Colab notebook itself -- anyone can make a copy, connect to a "T4 GPU" instance, and run it themselves to see.

import math
import numpy as np
import matplotlib.pyplot as plt
import jax

assert len(jax.devices("gpu")) == 1

def run_jax_kernel(c, fractal):
    z = c
    for i in range(1000):
        z = z**2 + c
        diverged = jax.numpy.absolute(z) > 2
        diverging_now = diverged & (fractal == 1000)
        fractal = jax.numpy.where(diverging_now, i, fractal)
    return fractal

run_jax_gpu_kernel = jax.jit(run_jax_kernel, backend="gpu")

def run_jax_gpu(height, width):

    mx = -0.69291874321833995150613818345974774914923989808007473759199
    my = 0.36963080032727980808623018005116209090839988898368679237704
    zw = 4 / 1e3

    y, x = jax.numpy.ogrid[(my-zw/2):(my+zw/2):height*1j, (mx-zw/2):(mx+zw/2):width*1j]
    c = x + y*1j
    fractal = jax.numpy.full(c.shape, 1000, dtype=np.int32)
    return np.asarray(run_jax_gpu_kernel(c, fractal).block_until_ready())

Takes about a minute to produce an image:

fig, ax = plt.subplots(1, 1, figsize=(15, 10))
ax.imshow(run_jax_gpu(2000, 3000));

Takes about a minute to report that the function only takes 70-80 ms to execute:

%%timeit -o
run_jax_gpu(2000, 3000)

Solution

The first thing to realize is that %timeit (and the cell form %%timeit) executes your code multiple times and then reports the average time per run. The number of times it executes is determined dynamically from the time of the first run.

The second thing to realize is that JAX code is just-in-time (JIT) compiled: the first time a jitted function is called with a given set of input shapes and dtypes, you incur a one-time compilation cost. Many things affect compilation cost, but functions that contain large Python for loops (say, 1000 or more iterations) tend to compile very slowly. The loop runs in Python while JAX traces the function, so every iteration is recorded as a separate set of operations; the fully unrolled program is then passed to XLA, and XLA compilation time scales approximately quadratically with the number of unrolled operations (there is some discussion of this at JAX Sharp Bits: Control Flow).
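The unrolling can be observed directly by counting the operations in the traced program. The sketch below uses `jax.make_jaxpr`, which traces a function without compiling it; `python_loop`, `fori` and `num_ops` are hypothetical helper names for illustration, and the loop body is a simplified `z**2 + c` step rather than the full kernel from the question.

```python
import jax
import jax.numpy as jnp

def python_loop(z, c, n):
    # Python executes this loop during tracing, so each iteration is
    # recorded as separate operations in the traced program.
    for _ in range(n):
        z = z**2 + c
    return z

def fori(z, c, n):
    # lax.fori_loop traces the body once and records a single loop primitive.
    return jax.lax.fori_loop(0, n, lambda i, z: z**2 + c, z)

def num_ops(f, n):
    # Count the top-level equations in the jaxpr produced by tracing f.
    x = jnp.zeros(8, dtype=jnp.complex64)
    closed = jax.make_jaxpr(lambda z, c: f(z, c, n))(x, x)
    return len(closed.jaxpr.eqns)

for n in (10, 100, 1000):
    print(f"n={n}: python loop -> {num_ops(python_loop, n)} ops, "
          f"fori_loop -> {num_ops(fori, n)} ops")
```

The Python-loop version grows linearly with `n`, and that whole program is what XLA has to compile; the `fori_loop` version stays the same size regardless of `n`.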

Put these together, and you'll see why you're observing these timings. Under %%timeit, the first call, which is used to decide how many loops to time, includes the very long compilation; the timed runs that follow reuse the compiled kernel and are very fast. The printed average covers only those fast runs, so it is very short compared to the first run and to the overall time the cell takes.
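The calibrate-then-repeat pattern can be reproduced without a GPU using Python's standard `timeit` module, whose `Timer.autorange()` and `Timer.repeat()` follow a scheme similar to IPython's %timeit. In this sketch, `fake_jitted_kernel` is a hypothetical stand-in that sleeps 0.5 s on its first call to simulate compilation; it is not JAX code.

```python
import time
import timeit

_compiled_sizes = set()  # hypothetical stand-in for JAX's compilation cache

def fake_jitted_kernel(n):
    """Hypothetical stand-in for a jitted function: the first call for a
    given n pays a simulated one-time compilation cost."""
    if n not in _compiled_sizes:
        time.sleep(0.5)          # simulated compilation
        _compiled_sizes.add(n)
    return sum(range(n))         # the cheap "real" work

timer = timeit.Timer(lambda: fake_jitted_kernel(10_000))

# Calibration: autorange() tries number = 1, 2, 5, 10, ... until one batch
# takes at least 0.2 s. The very first call already takes ~0.5 s because it
# includes the simulated compilation, so it settles on number = 1.
number, calibration_s = timer.autorange()

# Timed repeats: these reuse the "compiled" kernel and are fast.
runs = timer.repeat(repeat=7, number=number)
per_loop_ms = [1e3 * r / number for r in runs]

print(f"calibration call: {calibration_s:.2f} s, loops per repeat: {number}")
print(f"mean of timed runs: {sum(per_loop_ms) / len(per_loop_ms):.3f} ms")
```

The wall-clock time of the whole snippet is dominated by the calibration call, while the reported mean reflects only the cheap repeats, which is the same mismatch seen in the question.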

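When you run your code a single time to plot the results, you are mainly seeing the compilation time. Nothing amortizes that one-time cost over multiple calls, so the cell takes about a minute. Later calls with the same input shapes in the same session reuse the cached compiled function and should be fast.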
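One way to measure the two costs separately is JAX's ahead-of-time API, `jax.jit(f).lower(...).compile()`, which traces and compiles without executing. The sketch below uses a hypothetical, smaller `kernel` (100 unrolled iterations on a 1-D array of real-axis points) rather than the question's function, and runs on whatever default backend is available; the absolute timings will vary by machine.

```python
import time
import jax
import jax.numpy as jnp

def kernel(c):
    # Hypothetical simplified kernel: a Python loop, so it is unrolled at trace time.
    z = c
    for _ in range(100):
        z = z**2 + c
    return jnp.abs(z)

# Points on the real axis inside the Mandelbrot set, so z stays bounded.
c = jnp.linspace(-1.5, 0.25, 1024).astype(jnp.complex64)
jitted = jax.jit(kernel)

t0 = time.perf_counter()
compiled = jitted.lower(c).compile()   # trace + XLA compilation, no execution
compile_s = time.perf_counter() - t0

t0 = time.perf_counter()
out = compiled(c).block_until_ready()  # execution only
run_s = time.perf_counter() - t0

print(f"compile: {compile_s:.3f} s, run: {run_s * 1e3:.3f} ms")
```

Compiling (or simply calling the function once) before timing keeps the one-time cost out of the measurement, which is effectively what the calibration call does under %%timeit.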

The solution is to avoid writing Python for loops in your jitted function, so that the long compilation never happens. One possibility is lax.fori_loop, which lets you write iterative computations without the huge compilation-time penalty. It will incur a runtime penalty on GPU compared to the for-loop version, because the operations are executed sequentially rather than being parallelized by the compiler. In your case it might look like this:

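import jax

def run_jax_kernel(c, fractal):
    z = c
    def body_fun(i, carry):
        # One Mandelbrot step; the carry holds the current z values and escape times.
        z, fractal = carry
        z = z**2 + c
        diverged = jax.numpy.absolute(z) > 2
        # Record iteration i only for points that have not diverged before.
        diverging_now = diverged & (fractal == 1000)
        fractal = jax.numpy.where(diverging_now, i, fractal)
        return (z, fractal)
    # fori_loop traces body_fun once instead of unrolling 1000 copies of it.
    z, fractal = jax.lax.fori_loop(0, 1000, body_fun, (z, fractal))
    return fractal

# Use it with the same jax.jit wrapper and run_jax_gpu function as in the question.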
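A quick way to check that the fori_loop rewrite computes the same escape times as the unrolled loop is to run both on a small grid. This is a sketch under a few assumptions: `MAX_ITER = 30` replaces the 1000 used above so the unrolled version compiles quickly, the grid is built with `jnp.linspace` instead of `ogrid`, and the function names are hypothetical. The two versions should agree; rare disagreements could occur for points right at the set boundary if the compiler rounds the two programs differently.

```python
import jax
import jax.numpy as jnp

MAX_ITER = 30  # hypothetical small iteration count, also used as the "not diverged" sentinel

def escape_time_unrolled(c, fractal):
    z = c
    for i in range(MAX_ITER):  # unrolled at trace time
        z = z**2 + c
        diverging_now = (jnp.abs(z) > 2) & (fractal == MAX_ITER)
        fractal = jnp.where(diverging_now, i, fractal)
    return fractal

def escape_time_fori(c, fractal):
    def body_fun(i, carry):
        z, fractal = carry
        z = z**2 + c
        diverging_now = (jnp.abs(z) > 2) & (fractal == MAX_ITER)
        fractal = jnp.where(diverging_now, i, fractal)
        return (z, fractal)
    _, fractal = jax.lax.fori_loop(0, MAX_ITER, body_fun, (c, fractal))
    return fractal

# Small grid covering the whole Mandelbrot set.
y = jnp.linspace(-1.5, 1.5, 64)[:, None]
x = jnp.linspace(-2.0, 1.0, 96)[None, :]
c = (x + y * 1j).astype(jnp.complex64)
init = jnp.full(c.shape, MAX_ITER, dtype=jnp.int32)

fractal_unrolled = jax.jit(escape_time_unrolled)(c, init)
fractal_fori = jax.jit(escape_time_fori)(c, init)
agreement = float((fractal_unrolled == fractal_fori).mean())
print(f"fraction of matching pixels: {agreement:.4f}")
```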