Tags: jax, flax, nnx

Does vmap correctly split the RNG keys?


In the following code, dropout behaves randomly as expected when I call the model without vmap. With nnx.vmap, however, I no longer get that randomized behavior across the batch. Isn't splitting the RNG keys supposed to be one of the features of nnx.vmap?

import jax
import jax.numpy as jnp
from flax import nnx

# --- 1. Define a Simple Model with a Stateful Layer (Dropout) ---
# nnx.Dropout is given `rngs` so that it can draw random masks from the
# 'dropout' RNG stream; this is what makes the model's output random.

class SimpleDropoutModel(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    """Builds a Linear(10 -> 5) layer and a Dropout(rate=0.5) layer."""
    # Keep this order: both layers receive the same `rngs` object, so
    # reordering them could change which keys each layer ends up with.
    self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
    self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)

  def __call__(self, x: jnp.ndarray, *, train: bool) -> jnp.ndarray:
    """Projects a single input with the linear layer, then applies dropout.

    Dropout is active only when `train` is True (it runs with
    `deterministic=not train`).
    """
    return self.dropout(self.linear(x), deterministic=not train)

# --- 2. Initialization ---
# A fixed seed makes the run reproducible.
key = jax.random.PRNGKey(42)

# nnx.Rngs holds one stream per purpose: 'params' for initializing the Linear
# layer's weights and 'dropout' for the Dropout masks. Both streams are
# seeded from the same key here.
streams = nnx.Rngs(params=key, dropout=key)
model = SimpleDropoutModel(rngs=streams)

print("Model initialized successfully.")
print("Dropout Rate:", model.dropout.rate)
print("-" * 30)


# --- 3. Define and Transform the Batched Apply Function ---
# Apply the single-example model to a whole batch by composing nnx.vmap
# (batching) with nnx.jit (compilation).
#
# Decorators are applied bottom-up: nnx.jit wraps the function first, and
# nnx.vmap then wraps the jitted function. At call time, vmap is therefore
# the outer transform.
#
# NOTE(review): `in_axes=None` for the model broadcasts ALL of its state to
# every batch element, including the 'dropout' RNG state held by the Dropout
# layer. Every element therefore draws the same dropout key and gets the same
# mask, which is why the randomness across the batch disappears. nnx.vmap
# does not split RNG streams on its own. The 'dropout' stream has to be split
# explicitly (e.g. nnx.split_rngs), and that part of the model state has to be
# mapped over axis 0 (e.g. nnx.StateAxes).
@nnx.vmap(
    in_axes=(None, 0, None), # model state broadcast (incl. its RNG), x mapped over axis 0, train broadcast
    out_axes=0 # per-example outputs stacked along axis 0
)
@nnx.jit(static_argnames=["train"])
def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray, train: bool):
  """Applies the model to each row of `x` and stacks the results along axis 0.

  Args:
    model: the model; with `in_axes=None` its whole state (parameters and
      RNG) is shared by every row.
    x: batch of inputs, mapped over its leading axis.
    train: whether dropout is active; marked static for nnx.jit.

  Returns:
    The per-row model outputs, stacked along axis 0.
  """
  # The NNX transforms carry the model's state (parameters and RNG) across the transform boundary.
  return model(x, train=train)


# --- 4. Run the Demonstration ---
# Four identical rows of ten ones. If each batch element drew its own dropout
# key, the output rows would almost certainly differ. If the key is shared,
# they all come out the same.
batch_input = jnp.ones((4, 10))

divider = "-" * 30

print(f"Input batch shape: {batch_input.shape}")
print("Input batch:")
print(batch_input)
print(divider)
print("Running the batched model in training mode (dropout is active)...")

# Call the vmapped + jitted function, passing the model as its first argument.
output_batch = batched_apply(model, batch_input, train=True)

print(f"Output batch shape: {output_batch.shape}\n")
print("Output batch:")
print(output_batch)
print(divider)

# --- 5. Verification ---
# Compare every row against the first one. This only establishes that NOT ALL
# rows are identical; it does not check that every row is unique.
first_output = output_batch[0]
all_same = jnp.all(output_batch == first_output)

if all_same:
    print("❌ Verification failed: All outputs were the same.")
else:
    print("✅ Verification successful: The outputs are different for each sample in the batch.")
    print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")

Solution

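  • `nnx.vmap` does not split RNG keys on its own. With `in_axes=None` for the model, its dropout RNG state is broadcast, so every batch element draws the same key and gets the same mask. To get a different mask per element, split the `'dropout'` stream with `nnx.split_rngs` and map that part of the model state over axis 0 with `nnx.StateAxes`: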

    import jax
    import jax.numpy as jnp
    from flax import nnx
    
    # --- 1. Define a Simple Model with a Stateful Layer (Dropout) ---
    # nnx.Dropout draws its random masks from the 'dropout' RNG stream.
    # nnx.vmap does NOT split that stream automatically. Per-example masks need
    # the explicit nnx.split_rngs + nnx.StateAxes setup used for `batched_apply`.
    
    class SimpleDropoutModel(nnx.Module):
      """Linear projection (10 -> 5 features) followed by dropout (rate 0.5)."""

      def __init__(self, *, rngs: nnx.Rngs):
        """Initializes the model.

        Args:
          rngs: RNG streams. 'params' is needed to initialize the Linear
            layer. The Dropout layer receives the same object so that it can
            draw masks from the 'dropout' stream.
        """
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
        self.linear = nnx.Linear(in_features=10, out_features=5, rngs=rngs)
    
      def __call__(self, x: jnp.ndarray, *, train: bool) -> jnp.ndarray:
        """Applies the model to a single (unbatched) input.

        Args:
          x: one input example with 10 features.
          train: if True, dropout is active; if False, dropout runs in
            deterministic mode.

        Returns:
          The 5 output features, with dropout applied when `train` is True.
        """
        x = self.linear(x)
        # `deterministic` is the inverse of `train`: masks are drawn only in training.
        x = self.dropout(x, deterministic=not train)
        return x
    
    # --- 2. Initialization ---
    # A fixed seed makes the run reproducible.
    key = jax.random.PRNGKey(42)

    # nnx.Rngs holds one stream per purpose: 'params' initializes the Linear
    # layer's weights and 'dropout' feeds the Dropout masks.
    # NOTE(review): both streams are seeded with the same `key`. Consider
    # distinct seeds (e.g. via jax.random.split) so that weight initialization
    # and dropout masks do not start from an identical key. Doing so would
    # change the printed numbers.
    streams = nnx.Rngs(params=key, dropout=key)
    model = SimpleDropoutModel(rngs=streams)

    print("Model initialized successfully.")
    print("Dropout Rate:", model.dropout.rate)
    print("-" * 30)
    
    
    # --- 3. Define and Transform the Batched Apply Function ---
    # Apply the single-example model to a whole batch.
    #
    # Decorators are applied bottom-up, so at call time the transforms run
    # outermost-first as nnx.split_rngs -> nnx.vmap -> nnx.jit:
    #   * nnx.split_rngs splits the model's 'dropout' RNG stream into `bs`
    #     keys; `only='dropout'` leaves the other streams unsplit.
    #   * nnx.vmap maps over those split keys (axis 0 of the 'dropout' state,
    #     per `state_axes`) and over axis 0 of `x`. Each batch element
    #     therefore draws its own dropout key.
    #   * nnx.jit compiles the per-example function; `train` is static.
    bs = 4  # number of RNG splits; must equal the size of x's leading (mapped) axis
    
    # State matching the 'dropout' filter (the dropout RNG stream) is mapped
    # over axis 0. Everything else in the model (`...`) is broadcast.
    state_axes = nnx.StateAxes({'dropout': 0, ...: None})
    
    @nnx.split_rngs(splits=bs, only='dropout')
    @nnx.vmap(
        in_axes=(state_axes, 0, None), # model: 'dropout' RNG mapped over axis 0, rest broadcast; x mapped over axis 0; train broadcast
        out_axes=0 # per-example outputs stacked along axis 0
    )
    @nnx.jit(static_argnames=["train"])
    def batched_apply(model: SimpleDropoutModel, x: jnp.ndarray, train: bool):
      """Applies the model to each row of `x` with a per-row dropout key.

      Args:
        model: the model; its 'dropout' RNG state is split and mapped, and
          the rest of its state is shared by every row.
        x: batch of `bs` inputs, mapped over its leading axis.
        train: whether dropout is active; marked static for nnx.jit.

      Returns:
        The per-row model outputs, stacked along axis 0.
      """
      # The NNX transforms carry the model's state (parameters and RNG) across the transform boundary.
      return model(x, train=train)
    
    
    # --- 4. Run the Demonstration ---
    # `bs` identical rows of ten ones. Any difference between output rows can
    # only come from the per-row dropout masks.
    batch_input = jnp.ones((bs, 10))

    separator = "-" * 30

    print(f"Input batch shape: {batch_input.shape}")
    print("Input batch:")
    print(batch_input)
    print(separator)
    print("Running the batched model in training mode (dropout is active)...")

    # Put the model's submodules into training mode. `train=True` is also
    # passed explicitly below, and it drives `deterministic=not train` inside
    # the model.
    model.train()

    # Call the transformed function, passing the model as its first argument.
    output_batch = batched_apply(model, batch_input, train=True)

    print(f"Output batch shape: {output_batch.shape}\n")
    print("Output batch:")
    print(output_batch)
    print(separator)

    # --- 5. Verification ---
    # NOTE: this compares every row against the first one, so it only shows
    # that NOT ALL rows are identical. Individual rows can still coincide when
    # their dropout masks happen to match, so the success message below claims
    # more than is actually tested.
    first_output = output_batch[0]
    all_same = jnp.all(output_batch == first_output)

    if all_same:
        print("❌ Verification failed: All outputs were the same.")
    else:
        print("✅ Verification successful: The outputs are different for each sample in the batch.")
        print("This proves nnx.vmap correctly split the 'dropout' RNG stream.")
    
    
    

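    Output with jax: 0.7.0.dev20250704, flax: 0.10.6. Note that the verification only confirms the rows are not all identical. Individual rows can still coincide by chance, as the last two rows do here.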

    Output batch:
    [[0.         0.1736668  1.6533196  0.         0.        ]
     [0.         0.         1.6533196  0.         0.7218913 ]
     [0.09358063 0.         1.6533196  0.         0.7218913 ]
     [0.09358063 0.         1.6533196  0.         0.7218913 ]]
    ------------------------------
    ✅ Verification successful: The outputs are different for each sample in the batch.
    This proves nnx.vmap correctly split the 'dropout' RNG stream.