jax

how to log activation values using jax


I am following a JAX tutorial that trains MNIST using an MLP network. I am trying to add some additional code that saves the activation patterns at every layer except the last. Here is the modified code:

from collections import defaultdict
# this is my activation pattern logger
class ActivationLogger:
    def __init__(self, epoch):
        self.reset(epoch)

    def __call__(self, layer, activations):
        D = activations.shape[0]
        for i in range(D):
            self.activations[(layer, i)].append(
                    jax.lax.stop_gradient(activations[i]))

    def reset(self, epoch):
        self.epoch = epoch
        self.activations = defaultdict(list)

activation_logger = ActivationLogger(epoch=1)

...

def predict(params, image):
    # per-example predictions
    activations = image
    for l, (w, b) in enumerate(params[:-1]):
        outputs = jnp.dot(w, activations) + b
        activations = jnp.maximum(0, outputs)
        activation_logger(l+1, activations) # <- this was added

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

batched_predict = jax.vmap(
        predict, 
        in_axes=(None, 0), 
        out_axes=0)

@jax.jit
def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

When I run my training code, I keep getting the following error message:

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[].
This BatchTracer with object id 7541955024 was created on line:
  /var/folders/km/3nj8tmq56s16dsgc9_63530r0000gn/T/ipykernel_73750/3494160804.py:12:20 (ActivationLogger.__call__)

Any suggestions on how to fix this?


Solution

  • Python functions in JAX code are executed at trace time, not at runtime, so as written you're not logging concrete runtime values, but rather their abstract trace-time representations (the tracers mentioned in the error message).
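
    To make the distinction concrete, here is a small standalone example (separate from your training code, using only jax and jax.numpy) of a Python side effect running at trace time:

        import jax
        import jax.numpy as jnp

        @jax.jit
        def f(x):
            # this print is a Python side effect: it runs once, while f is being
            # traced, and it sees a tracer rather than the actual array values
            print("traced:", x)
            return x * 2

        f(jnp.arange(3.0))  # prints something like "traced: Traced<ShapedArray(float32[3])>..."
        f(jnp.arange(3.0))  # prints nothing: the cached compiled computation is reused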

    If you want to log runtime values, the best tool is probably jax.debug.callback; for info on using this, I'd suggest starting with External Callbacks in JAX.

    Using it in your case would look something like this:

        for l, (w, b) in enumerate(params[:-1]):
            outputs = jnp.dot(w, activations) + b
            activations = jnp.maximum(0, outputs)
            jax.debug.callback(activation_logger, l+1, activations)
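
    With this change, activation_logger is invoked at runtime rather than at trace time, and the values it receives are concrete arrays rather than tracers, so the logger itself barely needs to change. One possible sketch (the np.asarray conversion is optional, it just keeps host-side NumPy copies; jax.lax.stop_gradient is no longer needed here, since nothing is returned to the traced computation):

        import numpy as np
        from collections import defaultdict

        class ActivationLogger:
            def __init__(self, epoch):
                self.reset(epoch)

            def __call__(self, layer, activations):
                # called by jax.debug.callback at runtime with concrete values
                for i in range(activations.shape[0]):
                    self.activations[(layer, i)].append(np.asarray(activations[i]))

            def reset(self, epoch):
                self.epoch = epoch
                self.activations = defaultdict(list)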
    

    For more background on JAX's execution model, and why your function didn't work as expected when executed directly at trace time, a good place to start is How to think in JAX.
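
    As an aside, if you just want to inspect runtime values rather than accumulate them in a Python object, jax.debug.print is a convenience wrapper built on the same callback mechanism; inside the loop you could write something like:

        jax.debug.print("layer {l} activations: {a}", l=l+1, a=activations)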