jax, equinox

JAX/Equinox pipeline slows down after adding an integer argument to a loss function


I have the following training pipeline in JAX and Equinox. I want to pass the batch index to the loss function so that I can apply different logic depending on the index. Without the batch index, the training loop takes about 15 seconds, but when I pass the index, it slows down to about an hour. Could you explain why this happens? I'm new to JAX, sorry.

import typing as tp

import equinox as eqx
import jax_dataloader as jdl
import optax
from tqdm import tqdm

def fit_cv(model: eqx.Module,
           dataloader: jdl.DataLoader,
           optimizer: optax.GradientTransformation,
           loss: tp.Callable,
           n_steps: int = 1000):

    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
    dloss = eqx.filter_jit(eqx.filter_value_and_grad(loss))

    @eqx.filter_jit
    def step(model, data, opt_state, batch_index):
        loss_score, grads = dloss(model, data, batch_index)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_score

    loss_history = []
    # zip with range(n_steps) already stops after n_steps batches
    for batch_index, batch in tqdm(zip(range(n_steps), dataloader), total=n_steps):
        batch = batch[0]  # dataloader returns a tuple of size 1
        # batch_index is a plain Python int here
        model, opt_state, loss_score = step(model, batch, opt_state, batch_index)
        loss_history.append(loss_score)
    return model, loss_history

The loss function has the following signature:

def loss(self, model: eqx.Module, data: jnp.ndarray, batch_index: int):

In particular, I want to switch between two loss functions after N steps, so I probably need to know the concrete value of the batch index.
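For illustration, a minimal sketch of what happens if the switch is written as a Python if while the index is traced. It uses plain jax.jit so it runs without Equinox, and the loss bodies and SWITCH_STEPS value are hypothetical. Tracing fails because a traced index has a shape and dtype but no concrete value. Under eqx.filter_jit with a Python int, the same if does run, but only because the int is treated as static, which forces a recompile for every new value.

```python
import jax
import jax.numpy as jnp

SWITCH_STEPS = 3  # hypothetical value of N


@jax.jit
def loss_with_python_if(inputs, batch_index):
    # Under jax.jit, batch_index is an abstract tracer: it has a shape and dtype
    # but no concrete value, so a Python if cannot branch on it.
    if batch_index >= SWITCH_STEPS:
        return jnp.mean(inputs ** 2)  # hypothetical second loss
    return jnp.mean(jnp.abs(inputs))  # hypothetical first loss


try:
    loss_with_python_if(jnp.ones(2), 5)
    failed = False
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of this error.
    failed = True
    print("tracing failed with", type(err).__name__)
```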

Solution: use jax.lax.cond, which can branch on a traced batch index:

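        # batch_index may be a traced array, so select the loss with lax.cond
        # rather than with a Python if
        condition = (batch_index // self.switch_steps) % 2 == 1
        ...
        loss_value = jax.lax.cond(
            jnp.all(condition),  # lax.cond needs a scalar predicate
            lambda: loss1(inputs),  # odd switch periods
            lambda: loss2(inputs),  # even switch periods
        )
        return loss_value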
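A self-contained sketch of the jax.lax.cond approach, assuming two hypothetical losses (loss1, loss2) and a hypothetical module-level SWITCH_STEPS in place of self.switch_steps. Because batch_index is passed to jax.jit without being marked static, it is traced: the function is compiled once, and the branch is chosen at run time.

```python
import jax
import jax.numpy as jnp

SWITCH_STEPS = 3  # hypothetical stand-in for self.switch_steps


def loss1(inputs):
    # hypothetical first loss: mean squared value
    return jnp.mean(inputs ** 2)


def loss2(inputs):
    # hypothetical second loss: mean absolute value
    return jnp.mean(jnp.abs(inputs))


@jax.jit
def switched_loss(inputs, batch_index):
    # batch_index is traced here, so the branch is picked by lax.cond at run
    # time instead of by a Python if at trace time.
    condition = (batch_index // SWITCH_STEPS) % 2 == 1
    return jax.lax.cond(
        condition,
        lambda: loss1(inputs),
        lambda: loss2(inputs),
    )


inputs = jnp.array([-2.0, 1.0])
for i in range(7):
    print(i, float(switched_loss(inputs, i)))
```

With SWITCH_STEPS = 3, steps 0 to 2 and step 6 use loss2 (1.5 for these inputs), and steps 3 to 5 use loss1 (2.5).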

Solution

  • I suspect the issue is excessive recompilation. You are using filter_jit, which, according to the docs, has the following property:

    All JAX and NumPy arrays are traced, and all other types are held static.

    Each time a static argument to a JIT-compiled function changes, it triggers a recompilation. This means that if batch_index is a Python int, the function will be recompiled every time you call it with a new value, i.e. on every training step.

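    As a fix, I would recommend using regular old jax.jit, which requires you to specify static arguments explicitly instead of trying to make the choice for you (potential surprises like this are one of the reasons JAX made this design choice; as the Zen of Python says, explicit is better than implicit). If you use jax.jit and don't mark batch_index as static, you shouldn't see this recompilation penalty. Note that jax.jit requires every non-static argument to be a pytree of arrays, which may not hold for an Equinox model with non-array fields such as activation functions; such a model would first need to be split into its array and non-array parts (for example with eqx.partition and eqx.combine).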

    Alternatively, if you want to keep using filter_jit, then you could change your step call to this:

    model, opt_state, loss_score = step(model, batch, opt_state, jnp.asarray(batch_index))
    

    With this change, filter_jit will no longer make the batch index static. Of course, either of these suggestions requires that your loss function be compatible with a dynamic (traced) batch_index, which can't be determined from the information included in your question.
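The recompilation effect can be checked with a minimal sketch that counts how often a function is traced. It uses plain jax.jit so it runs without Equinox: static_argnums=1 stands in for how filter_jit treats a Python int, and loss_like is a hypothetical stand-in for the real loss. A Python side effect such as incrementing a counter runs only while JAX traces the function, so the counter equals the number of traces, each of which is followed by a compilation.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def loss_like(x, batch_index):
    """Hypothetical stand-in for a loss that depends on the batch index."""
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(x) * batch_index


x = jnp.ones(3)

# batch_index held static, as eqx.filter_jit does with a Python int:
# every new value is a new cache key, so the function is retraced and recompiled.
static_step = jax.jit(loss_like, static_argnums=1)
for i in range(5):
    static_step(x, i)
static_traces = trace_count

# batch_index traced: every call sees an int32 scalar with the same shape and
# dtype, so the single compiled version is reused.
trace_count = 0
dynamic_step = jax.jit(loss_like)
for i in range(5):
    dynamic_step(x, i)
dynamic_traces = trace_count

print(static_traces, dynamic_traces)  # 5 1
```

The static version is traced once per distinct index, while the dynamic version is traced once in total; in the training loop from the question, the static case means one compilation per step across all n_steps.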