
Why does JAX's grad not always print inside the cost function?


I am new to JAX and trying to use it with PennyLane and optax to optimize a simple quantum circuit. However, I noticed that my print statement inside the cost function does not execute in every iteration. Specifically, it prints only in the first couple of iterations and then stops appearing.

The quantum circuit itself is not meaningful; I simplified the example as much as possible. I don't think the circuit is relevant to the question, but I have included it for completeness.

Here is my code:

import pennylane as qml
import jax
import jax.numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)

device = qml.device("default.qubit", wires=1)


@qml.qnode(device, interface='jax')
def circuit(params):
    qml.RX(params, wires=0)
    return qml.expval(qml.PauliZ(0))

def cost(params):
    print('Evaluating')
    return circuit(params)

# Define optimizer
params = jnp.array(0.1)
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(params)

# JIT the gradient function
grad = jax.jit(jax.grad(cost))

for epoch in range(5):
    print(f'{epoch = }')
    grad_value = grad(params)
    updates, opt_state = opt.update(grad_value, opt_state)
    params = optax.apply_updates(params, updates)

Expected output:

epoch = 0
Evaluating
epoch = 1
Evaluating
epoch = 2
Evaluating
epoch = 3
Evaluating
epoch = 4
Evaluating

Actual output:

epoch = 0
Evaluating
epoch = 1
Evaluating
epoch = 2
epoch = 3
epoch = 4

Question:

Why is the print statement inside cost not executed after the first iteration? Is JAX caching the function call or optimizing it in a way that skips execution? How can I ensure that cost is evaluated in every iteration?


Solution

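  • When working with JAX, it is important to understand the difference between "trace time" and "runtime". The first time a jitted function is called with arguments of a given shape, dtype and weak type, JAX traces it: it runs the Python code once with abstract values to record the computational graph, compiles that graph, and caches the compiled version. Later calls with arguments of the same abstract type ("runtime") go straight to the cached compiled code and never execute the Python function body. Python print statements therefore run only at trace time, not at runtime, because the function's code has effectively been replaced by the compiled version. The second "Evaluating" in your output is most likely a retrace: jnp.array(0.1) is a weakly typed array, while the params returned by the optax update probably are not, so the second call did not match the cached signature and JAX traced and compiled once more before reusing that version on every later call.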
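A minimal sketch of trace-time versus runtime behavior, without PennyLane or optax. The function `square` and the counter `trace_count` are hypothetical, for illustration only. The plain Python side effect (standing in for `print`) runs once per trace, not once per call, and a new trace happens whenever the argument's abstract type changes, including a change from a weakly typed to a strongly typed array of the same dtype.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only while JAX traces `square`

@jax.jit
def square(x):
    global trace_count
    trace_count += 1  # plain Python side effect, like print(): runs at trace time only
    return x ** 2

square(jnp.array(1.0))                     # first call: trace + compile (weakly typed input)
square(jnp.array(2.0))                     # same shape, dtype and weak type: cached code, no trace
square(jnp.array(3.0, dtype=jnp.float32))  # explicit dtype, not weakly typed: traces again
square(jnp.ones(3))                        # new shape: traces again

print(trace_count)  # 3
```

Only the calls that change the abstract signature of the argument re-run the Python body; the call with a new value but the same signature reuses the compiled code.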

    To print at runtime, use JAX's dedicated jax.debug.print function, which is embedded in the compiled computation and runs on every call:

    def cost(params):
        jax.debug.print('Evaluating')
        return circuit(params)
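A self-contained sketch mirroring the question's jax.jit(jax.grad(cost)) setup. As an assumption, jnp.cos stands in for the PennyLane circuit so the snippet runs without PennyLane, and the names `record`, `runtime_calls` and `trace_count` are hypothetical. It combines jax.debug.print with the related jax.debug.callback utility to count executions: the plain Python counter advances only once (at trace time), while the callback fires on every call.

```python
import jax
import jax.numpy as jnp

runtime_calls = []  # hypothetical log, filled by the callback on every execution
trace_count = 0     # incremented by plain Python code, i.e. only while tracing

def record(p):
    # Called by JAX at runtime with the concrete value of `params`
    runtime_calls.append(float(p))

def cost(params):
    global trace_count
    trace_count += 1  # trace time only, like print('Evaluating')
    jax.debug.print("Evaluating at params = {p}", p=params)  # runtime print
    jax.debug.callback(record, params)  # runtime side effect
    return jnp.cos(params)  # assumed stand-in for the PennyLane circuit

grad = jax.jit(jax.grad(cost))

for value in [0.1, 0.2, 0.3]:
    grad(jnp.float32(value))  # same shape/dtype on every call: a single trace

jax.effects_barrier()  # wait until all pending debug callbacks have run
print(trace_count, len(runtime_calls))  # 1 3
```

Passing inputs of a fixed dtype (here jnp.float32) keeps the abstract signature stable, so the function is traced only once while the debug print and callback still run on each iteration.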
    
    

    More on the jax.debug utilities: https://docs.jax.dev/en/latest/debugging/index.html

    And JIT compilation: https://docs.jax.dev/en/latest/jit-compilation.html