lstmrecurrent-neural-networkjitjaxflax

jax and flax not playing nicely with each other


I want to implement a neural network with multiple LSTM gates stacked one after the other.I set the hidden states to 0, as suggested here. When I try to run the code, I get

JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)

When I try to replace jax.lax.scan by flax.linen.scan, it gives another error. Not quite sure how to proceed or what's actually going wrong here. Code attached below. Thanks!

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence


class LSTMModel(nn.Module):
lstm_hidden_size: int
num_lstm_layers: int
linear_layer_sizes: Sequence[int]
mean_aggregation: bool

def initialize_carry(self, batch_size, feature_size=1):
    """Initialize carry states with zeros for all LSTM layers."""
    return [
        (
            # Hidden state (h)
            jnp.zeros((batch_size, self.lstm_hidden_size)),
            # Cell state (c)
            jnp.zeros((batch_size, self.lstm_hidden_size)),
        )
        for _ in range(self.num_lstm_layers)
    ]

@nn.compact
def __call__(self, x, carry=None):
    if carry is None:
        raise ValueError(
            "Carry must be initialized explicitly using `initialize_carry`."
        )

    # Expand 2D input to 3D (if necessary)
    if x.ndim == 2:
        # [batch_size, sequence_length] -> [batch_size, sequence_length, 1]
        x = jnp.expand_dims(x, axis=-1)

    # Process through LSTM layers
    for i in range(self.num_lstm_layers):
        lstm_cell = nn.LSTMCell(
            features=self.lstm_hidden_size, name=f'lstm_cell_{i}')

        def step_fn(carry, xt):
            new_carry, yt = lstm_cell(carry, xt)
            return new_carry, yt

        # Use lax.scan to process the sequence
        carry[i], outputs = jax.lax.scan(step_fn, carry[i], x)
        x = outputs  # Update x for the next layer

    # Aggregate outputs
    if self.mean_aggregation:
        x = jnp.mean(x, axis=1)  # Average over the sequence
    else:
        x = x[:, -1, :]  # Use the last output

    # Pass through linear layers
    for size in self.linear_layer_sizes:
        x = nn.Dense(features=size)(x)
        x = nn.elu(x)

    # Final output layer
    x = nn.Dense(features=1)(x)
    return x


# Model hyperparameters
lstm_hidden_size = 64
num_lstm_layers = 2
linear_layer_sizes = [32, 16]
mean_aggregation = False

# Initialize model
model = LSTMModel(
    lstm_hidden_size=lstm_hidden_size,
    num_lstm_layers=num_lstm_layers,
    linear_layer_sizes=linear_layer_sizes,
    mean_aggregation=mean_aggregation
)

# Dummy input: batch of sequences with 10 timesteps
key = jax.random.PRNGKey(0)
# [batch_size, sequence_length, feature_size]
dummy_input = jax.random.normal(key, (32, 10, 1))

# Initialize carry states
carry = model.initialize_carry(
    batch_size=dummy_input.shape[0], feature_size=dummy_input.shape[-1])

# Initialize parameters
params = model.init(key, dummy_input, carry)

# Apply the model
outputs = model.apply(params, dummy_input, carry)

# Should print: [batch_size, 1]
print("Model output shape:", outputs.shape)

Solution

  • Consider using nn.RNN to simplify your code:

    lstm = nn.RNN(
      nn.LSTMCell(features=self.lstm_hidden_size),
      name=f'lstm_cell_{i}'
    )
    outputs = lstm(x)
    

    RNN will handle the carries for you. If you really want to handle the carries yourself you could use return_carry and initial_carry:

    lstm = nn.RNN(
      nn.LSTMCell(features=self.lstm_hidden_size),
      return_carry=True, 
      name=f'lstm_cell_{i}'
    )
    carry[i], outputs = lstm(x, initial_carry=carry[i])