I have the following training pipeline in JAX and Equinox. I want to pass a batch index to the loss function in order to apply different logic depending on the index. Without the batch index, the training loop takes about 15 seconds, but if I pass an index, it slows down to about an hour. Could you explain why this happens? I'm new to JAX, sorry.
```python
import equinox as eqx
import jax_dataloader as jdl
import optax
import typing as tp
from tqdm import tqdm


def fit_cv(model: eqx.Module,
           dataloader: jdl.DataLoader,
           optimizer: optax.GradientTransformation,
           loss: tp.Callable,
           n_steps: int = 1000):
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
    dloss = eqx.filter_jit(eqx.filter_value_and_grad(loss))

    @eqx.filter_jit
    def step(model, data, opt_state, batch_index):
        loss_score, grads = dloss(model, data, batch_index)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_score

    loss_history = []
    for batch_index, batch in tqdm(zip(range(n_steps), dataloader), total=n_steps):
        if batch_index >= n_steps:  # redundant given the zip, but harmless
            break
        batch = batch[0]  # dataloader returns a tuple of size (1,)
        model, opt_state, loss_score = step(model, batch, opt_state, batch_index)
        loss_history.append(loss_score)
    return model, loss_history
```
The loss function has the following signature:

```python
def loss(self, model: eqx.Module, data: jnp.ndarray, batch_index: int):
```

In particular, I want to switch between two loss functions after N steps, so I probably need to know the concrete value of the batch index.
Solution: use `jax.lax.cond` inside the loss:

```python
condition = (batch_index // self.switch_steps) % 2 == 1
...
loss_value = jax.lax.cond(
    jnp.all(condition),
    lambda: loss1(inputs),
    lambda: loss2(inputs),
)
return loss_value
```
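For context, here is a self-contained sketch of how such a switching loss can be put together (`loss1`, `loss2`, and `switch_steps` are placeholder names; `lax.cond` requires both branches to return values of the same shape and dtype):

```python
import jax

def make_switching_loss(loss1, loss2, switch_steps):
    """Alternate between two loss functions every `switch_steps` batches."""
    def loss(model, data, batch_index):
        # This integer arithmetic traces fine under jit, as long as
        # batch_index is a traced array rather than a static Python int.
        in_odd_phase = (batch_index // switch_steps) % 2 == 1
        return jax.lax.cond(
            in_odd_phase,
            lambda: loss1(model, data),
            lambda: loss2(model, data),
        )
    return loss
```

Because `lax.cond` traces both branches at compile time, each loss must be valid for the same input shapes; only the selected branch is executed at run time.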
I suspect the issue is excessive recompilation. You are using `filter_jit`, which according to the docs has the following property:

> All JAX and NumPy arrays are traced, and all other types are held static.

Each time a static argument to a JIT-compiled function changes, it triggers a recompilation. This means that if `batch_index` is a Python `int`, then each time you call your function with a new value, the function will be recompiled.
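You can see this directly: a Python-side `print` inside a jitted function runs only while the function is being traced, so it fires once per compilation (a minimal sketch, not your code):

```python
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_jit
def f(x, i):
    print("tracing with i =", i)  # Python side effect: runs only at trace time
    return x + i

x = jnp.ones(3)
f(x, 0)               # prints: static Python int, first compilation
f(x, 1)               # prints again: new static value forces a recompile
f(x, jnp.asarray(2))  # prints once more: i is now a traced array...
f(x, jnp.asarray(3))  # ...and this call reuses that compilation silently
```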
As a fix, I would recommend using regular old `jax.jit`, which requires you to explicitly specify static arguments instead of the function trying to make the choice for you (potential surprises like this are one of the reasons why JAX has made this design choice; as the Zen of Python says, explicit is better than implicit). If you use `jax.jit` and don't mark `batch_index` as static, you shouldn't see this recompilation penalty.
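To illustrate with a minimal standalone sketch (not your `step` function, since passing a whole Equinox module through plain `jax.jit` needs extra care around non-array fields):

```python
import jax
import jax.numpy as jnp

@jax.jit  # batch_index is not listed in static_argnums, so it is traced
def scale(x, batch_index):
    return x * batch_index

x = jnp.ones(3)
for i in range(5):
    scale(x, i)  # compiled once on the first call; later indices reuse it
```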
Alternatively, if you want to keep using `filter_jit`, then you could change your step call to this:

```python
step(model, batch, opt_state, jnp.asarray(batch_index))
```

With this change, `filter_jit` will no longer decide to make the batch index static. Of course, either of these suggestions requires that `loss` is compatible with a dynamic `batch_index`, which can't be determined from the information included in your question.
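For completeness, a sketch of the loop body with that one-line change applied:

```python
for batch_index, batch in tqdm(zip(range(n_steps), dataloader), total=n_steps):
    batch = batch[0]
    # Passing the index as a JAX array keeps it traced under filter_jit,
    # so every iteration reuses the single compiled version of `step`.
    model, opt_state, loss_score = step(
        model, batch, opt_state, jnp.asarray(batch_index)
    )
    loss_history.append(loss_score)
```

A `jax.lax.cond` on a traced condition, as in your solution, then runs without triggering any further recompilation.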