In the following code, when I remove the vmap I get the expected randomized behavior; with vmap, I don't anymore. Isn't handling the RNGs across the batch supposed to be one of the features of nnx.vmap?
import jax
import jax.numpy as jnp
from flax import nnx
# --- 1. Define a Simple Model with a Stateful Layer (Dropout) ---
# We use nnx.Dropout because it requires random numbers, making it a stateful
# operation that benefits from nnx.vmap's automatic RNG splitting.
class SimpleDropoutModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        """Initializes the model."""
        # The dropout layer needs an RNG stream to generate random masks.
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)

    def __call__(self, x: jnp.ndarray, *, train: bool) -> jnp.ndarray:
        """Applies the model to a single input."""
        # The `deterministic` flag controls whether dropout is active.
        # We pass `not train` to it.
        x = self.linear(x)
        x = self.dropout(x, deterministic=not train)
        return x
# --- 2. Initialization ---
# Create a PRNG key for reproducibility.
key = jax.random.PRNGKey(42)
# Instantiate the model. NNX requires an `nnx.Rngs` object to manage
# different random number streams (e.g., for 'params' and 'dropout').
# We need to provide an RNG stream for 'params' as well for the Linear layer.
model = SimpleDropoutModel(rngs=nnx.Rngs(params=key, dropout=key))
print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)
# --- 3. Define and Transform the Batched Apply Function ---
# We want to apply our model to a whole batch of data.
# We compose nnx.vmap and nnx.jit to create an efficient, batched function.
# Define a helper function that takes the model, inputs, and train flag.
# Apply nnx.vmap and nnx.jit as decorators.
# Note: decorators apply bottom-up, so jit wraps the function first and
# vmap then maps the jitted function over the batch.
@nnx.vmap(
    in_axes=(None, 0, None),  # model is not vmapped, x is vmapped over axis 0, train is not vmapped
    out_axes=0                # output is vmapped over axis 0
)
@nnx.jit(static_argnames=["train"])
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray, train: bool):
    """Applies the model to a batch of inputs."""
    # NNX will handle the state and RNGs of the model instance passed to this function.
    return model(x, train=train)
# --- 4. Run the Demonstration ---
# Create a dummy batch of 4 identical inputs. Each input is a vector of 10 ones.
batch_input = jnp.ones((4, 10))
print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print("-" * 30)
print("Running the batched model in training mode (dropout is active)...")
# Run the JIT-compiled, vmapped function.
# Pass the model instance as the first argument. NNX will handle its state and RNGs.
output_batch = batched_apply(model, batch_input, train=True)
print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print("-" * 30)
# --- 5. Verification ---
# Because dropout is random and nnx.vmap correctly split the RNG keys,
# each row in the output batch should be different, even though the inputs were identical.
# We verify that not all outputs are the same.
first_output = output_batch[0]
all_same = jnp.all(jnp.all(output_batch == first_output, axis=1))
if not all_same:
    print("✅ Verification successful: The outputs are different for each sample in the batch.")
    print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")
else:
    print("❌ Verification failed: All outputs were the same.")
To make dropout work together with vmap in Flax NNX, we need to use nnx.split_rngs and nnx.StateAxes. With in_axes=(None, 0, None), the whole model, including its dropout RNG state, is broadcast to every batch element, so every row sees the same dropout key and therefore the same mask. nnx.split_rngs splits the 'dropout' RNG stream into one key per batch element, and nnx.StateAxes tells nnx.vmap to map that split RNG state over axis 0 while broadcasting the rest of the model state:
import jax
import jax.numpy as jnp
from flax import nnx
# --- 1. Define a Simple Model with a Stateful Layer (Dropout) ---
# We use nnx.Dropout because it requires random numbers, making it a stateful
# operation that benefits from nnx.vmap's automatic RNG splitting.
class SimpleDropoutModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        """Initializes the model."""
        # The dropout layer needs an RNG stream to generate random masks.
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)

    def __call__(self, x: jnp.ndarray, *, train: bool) -> jnp.ndarray:
        """Applies the model to a single input."""
        # The `deterministic` flag controls whether dropout is active.
        # We pass `not train` to it.
        x = self.linear(x)
        x = self.dropout(x, deterministic=not train)
        return x
# --- 2. Initialization ---
# Create a PRNG key for reproducibility.
key = jax.random.PRNGKey(42)
# Instantiate the model. NNX requires an `nnx.Rngs` object to manage
# different random number streams (e.g., for 'params' and 'dropout').
# We need to provide an RNG stream for 'params' as well for the Linear layer.
model = SimpleDropoutModel(rngs=nnx.Rngs(params=key, dropout=key))
print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)
# --- 3. Define and Transform the Batched Apply Function ---
# We want to apply our model to a whole batch of data.
# We compose nnx.vmap and nnx.jit to create an efficient, batched function.
# Define a helper function that takes the model, inputs, and train flag.
# Apply nnx.split_rngs, nnx.vmap, and nnx.jit as decorators.
# Decorators apply bottom-up: jit wraps the function first, then vmap, and
# split_rngs runs outermost so the dropout keys are split before vmapping.
bs = 4
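# Map only the dropout RNG state over the batch axis (axis 0); broadcast all
# other state (e.g. the Linear parameters) unchanged.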
state_axes = nnx.StateAxes({'dropout': 0, ...: None})
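# Split the 'dropout' RNG stream into bs independent keys, one per batch
# element, before vmapping.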
@nnx.split_rngs(splits=bs, only='dropout')
@nnx.vmap(
    in_axes=(state_axes, 0, None),  # dropout RNG state and x are mapped over axis 0, everything else is broadcast
    out_axes=0                      # output is vmapped over axis 0
)
@nnx.jit(static_argnames=["train"])
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray, train: bool):
    """Applies the model to a batch of inputs."""
    # NNX will handle the state and RNGs of the model instance passed to this function.
    return model(x, train=train)
# --- 4. Run the Demonstration ---
# Create a dummy batch of 4 identical inputs. Each input is a vector of 10 ones.
batch_input = jnp.ones((bs, 10))
print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print("-" * 30)
print("Running the batched model in training mode (dropout is active)...")
model.train()
# Run the JIT-compiled, vmapped function.
# Pass the model instance as the first argument. NNX will handle its state and RNGs.
output_batch = batched_apply(model, batch_input, train=True)
print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print("-" * 30)
# --- 5. Verification ---
# Because dropout is random and nnx.vmap correctly split the RNG keys,
# each row in the output batch should be different, even though the inputs were identical.
# We verify that not all outputs are the same.
first_output = output_batch[0]
all_same = jnp.all(jnp.all(output_batch == first_output, axis=1))
if not all_same:
    print("✅ Verification successful: The outputs are different for each sample in the batch.")
    print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")
else:
    print("❌ Verification failed: All outputs were the same.")
Output with jax 0.7.0.dev20250704 and flax 0.10.6:
Output batch:
[[0. 0.1736668 1.6533196 0. 0. ]
[0. 0. 1.6533196 0. 0.7218913 ]
[0.09358063 0. 1.6533196 0. 0.7218913 ]
[0.09358063 0. 1.6533196 0. 0.7218913 ]]
------------------------------
✅ Verification successful: The outputs are different for each sample in the batch.
This proves nnx.vmap correctly split the 'dropout' RNG stream.
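For reference, here is the same pattern stripped down to just a Dropout layer, as a minimal sketch under the same Flax NNX API as above (the helper name apply_dropout is only for illustration); it isolates what split_rngs and StateAxes contribute:

import jax
import jax.numpy as jnp
from flax import nnx

# A single Dropout layer with its own 'dropout' RNG stream.
dropout = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(dropout=jax.random.PRNGKey(0)))

# Map only the dropout RNG state over axis 0; broadcast everything else.
state_axes = nnx.StateAxes({'dropout': 0, ...: None})

@nnx.split_rngs(splits=4, only='dropout')  # one dropout key per batch element
@nnx.vmap(in_axes=(state_axes, 0), out_axes=0)
def apply_dropout(layer: nnx.Dropout, x: jnp.ndarray) -> jnp.ndarray:
    return layer(x, deterministic=False)

out = apply_dropout(dropout, jnp.ones((4, 8)))
print(out)  # each row should show a different dropout mask

Without the split_rngs/StateAxes pair, the layer's single dropout key is broadcast across the batch and all rows end up with the same mask, which is exactly the behavior observed in the original snippet.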