I am following a JAX tutorial that trains MNIST using an MLP network. I am trying to add code that saves the activation patterns at every layer except the last. Here is the modified code:
from collections import defaultdict

# this is my activation pattern logger
class ActivationLogger:
    def __init__(self, epoch):
        self.reset(epoch)

    def __call__(self, layer, activations):
        D = activations.shape[0]
        for i in range(D):
            self.activations[(layer, i)].append(
                jax.lax.stop_gradient(activations[i]))

    def reset(self, epoch):
        self.epoch = epoch
        self.activations = defaultdict(list)

activation_logger = ActivationLogger(epoch=1)
...
def predict(params, image):
    # per-example predictions
    activations = image
    for l, (w, b) in enumerate(params[:-1]):
        outputs = jnp.dot(w, activations) + b
        activations = jnp.maximum(0, outputs)
        activation_logger(l+1, activations)  # <- this was added
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

batched_predict = jax.vmap(
    predict,
    in_axes=(None, 0),
    out_axes=0)

@jax.jit
def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)
When I run my training code, I keep getting the following error message:
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[].
This BatchTracer with object id 7541955024 was created on line:
/var/folders/km/3nj8tmq56s16dsgc9_63530r0000gn/T/ipykernel_73750/3494160804.py:12:20 (ActivationLogger.__call__)
Any suggestions on how to fix this?
Python functions in JAX code are executed at trace time, not at runtime, so as written you're not logging concrete runtime values, but rather their abstract trace-time representations (the tracers mentioned in your error message).
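You can see this directly with an ordinary print; here's a minimal sketch, separate from your code:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # This print is a Python side effect: it runs once, at trace time,
    # and sees an abstract tracer rather than a concrete value.
    print("traced value:", x)
    return x * 2

f(jnp.arange(3))  # prints something like: traced value: Traced<ShapedArray(int32[3])>
f(jnp.arange(3))  # prints nothing: the cached compiled function is reused

Your logger hits the same trace-time behavior, except that it ends up trying to convert a tracer to a NumPy array, which is what raises the TracerArrayConversionError.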
If you want to log runtime values, the best tool is probably jax.debug.callback; for info on using this, I'd suggest starting with External Callbacks in JAX.
Using it in your case would look something like this:
for l, (w, b) in enumerate(params[:-1]):
    outputs = jnp.dot(w, activations) + b
    activations = jnp.maximum(0, outputs)
    jax.debug.callback(activation_logger, l+1, activations)
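For reference, here is a self-contained sketch of the same pattern with a toy stand-in for one layer (the log_activations helper and the input values are made up for illustration):

import numpy as np
import jax
import jax.numpy as jnp

logged = []

def log_activations(layer, acts):
    # At runtime the callback receives concrete values, so ordinary
    # Python side effects like appending to a list work here.
    logged.append((int(layer), np.asarray(acts)))

@jax.jit
def f(x):
    h = jnp.maximum(0, x)  # toy stand-in for one ReLU layer
    jax.debug.callback(log_activations, 1, h)
    return h.sum()

f(jnp.array([-1.0, 2.0, 3.0]))
print(logged)  # [(1, array([0., 2., 3.], dtype=float32))]

Two things to note: under jax.vmap (as in your batched_predict), jax.debug.callback should be invoked once per batch element with unbatched values, so your logger still sees per-example activations; and because the callback receives concrete values, the jax.lax.stop_gradient call inside the logger is no longer necessary.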
For more background on JAX's execution model, and why your function didn't work as expected when executed directly at trace time, a good place to start is How to think in JAX.