I am new to JAX and trying to use it with PennyLane and optax to optimize a simple quantum circuit. However, I noticed that my print statement inside the cost function does not execute in every iteration. Specifically, it prints only once at the beginning and then stops appearing.
The quantum circuit itself does not make sense; I just wanted to simplify the example as much as possible. I believe the circuit is not actually relevant to the question, but it's included as an example.
Here is my code:
import pennylane as qml
import jax
import jax.numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)

device = qml.device("default.qubit", wires=1)

@qml.qnode(device, interface='jax')
def circuit(params):
    qml.RX(params, wires=0)
    return qml.expval(qml.PauliZ(0))

def cost(params):
    print('Evaluating')
    return circuit(params)

# Define optimizer
params = jnp.array(0.1)
opt = optax.adam(learning_rate=0.1)
opt_state = opt.init(params)

# JIT the gradient function
grad = jax.jit(jax.grad(cost))

for epoch in range(5):
    print(f'{epoch = }')
    grad_value = grad(params)
    updates, opt_state = opt.update(grad_value, opt_state)
    params = optax.apply_updates(params, updates)
Expected output:
epoch = 0
Evaluating
epoch = 1
Evaluating
epoch = 2
Evaluating
epoch = 3
Evaluating
epoch = 4
Evaluating
Actual output:
epoch = 0
Evaluating
epoch = 1
Evaluating
epoch = 2
epoch = 3
epoch = 4
Why is the print statement inside cost not executed after the first iteration? Is JAX caching the function call or optimizing it in a way that skips execution? How can I ensure that cost is evaluated in every iteration?
When working with JAX, it is important to understand the difference between "trace time" and "runtime". For JIT compilation, JAX performs an abstract evaluation of the function the first time it is called. This "traces" the function's computational graph and produces a compiled replacement, which is cached and invoked on subsequent calls ("runtime"). Python's print statements are executed only at trace time, not at runtime, because the function's code has effectively been replaced by the compiled version. JAX traces again whenever the function is called with inputs of a different shape or dtype, so the message can appear more than once, but not on every call.
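The trace-time behaviour can be seen in a minimal sketch without PennyLane. Here `square`, `square_jit`, and the `trace_count` counter are hypothetical names made up for illustration; the counter is updated by an ordinary Python side effect, so it behaves the same way a print call would:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, bumped by a plain Python side effect


def square(x):
    global trace_count
    trace_count += 1  # like print(): runs only while JAX traces the function
    return x ** 2


square_jit = jax.jit(square)

# Three calls with the same shape and dtype: traced once, then served from the cache.
for i in range(3):
    square_jit(jnp.full((3,), float(i)))
traces_same_shape = trace_count

# A call with a different input shape forces a new trace.
square_jit(jnp.ones((4,)))
traces_after_new_shape = trace_count
```

The counter only moves when a new trace is needed, which is why a Python-level print inside a jitted function cannot be relied on to report every call.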
To print at runtime, JAX provides the special jax.debug.print function:
def cost(params):
    jax.debug.print('Evaluating')
    return circuit(params)
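jax.debug.print also accepts format placeholders, so the traced value itself can be shown at runtime. Below is a rough sketch of the training loop under some assumptions: jnp.cos stands in for circuit(params) so that the example runs without PennyLane, and a plain gradient-descent step replaces the optax update; neither is part of the original code.

```python
import jax
import jax.numpy as jnp


def cost(params):
    # Staged into the compiled computation, so it runs on every call.
    jax.debug.print("Evaluating at params = {}", params)
    return jnp.cos(params)  # assumption: stand-in for circuit(params)


grad = jax.jit(jax.grad(cost))

params = jnp.float32(0.1)
grads = []
for epoch in range(3):
    g = grad(params)
    grads.append(float(g))
    params = params - 0.1 * g  # plain gradient step instead of optax
```

With this version the message should appear once per call to grad, including the calls that reuse the cached compiled function.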
More on the jax.debug utilities: https://docs.jax.dev/en/latest/debugging/index.html
And JIT compilation: https://docs.jax.dev/en/latest/jit-compilation.html