I want to implement a neural network with multiple LSTM gates stacked one after the other.I set the hidden states to 0, as suggested here. When I try to run the code, I get
JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
When I try to replace jax.lax.scan by flax.linen.scan, it gives another error. Not quite sure how to proceed or what's actually going wrong here. Code attached below. Thanks!
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence
class LSTMModel(nn.Module):
lstm_hidden_size: int
num_lstm_layers: int
linear_layer_sizes: Sequence[int]
mean_aggregation: bool
def initialize_carry(self, batch_size, feature_size=1):
"""Initialize carry states with zeros for all LSTM layers."""
return [
(
# Hidden state (h)
jnp.zeros((batch_size, self.lstm_hidden_size)),
# Cell state (c)
jnp.zeros((batch_size, self.lstm_hidden_size)),
)
for _ in range(self.num_lstm_layers)
]
@nn.compact
def __call__(self, x, carry=None):
if carry is None:
raise ValueError(
"Carry must be initialized explicitly using `initialize_carry`."
)
# Expand 2D input to 3D (if necessary)
if x.ndim == 2:
# [batch_size, sequence_length] -> [batch_size, sequence_length, 1]
x = jnp.expand_dims(x, axis=-1)
# Process through LSTM layers
for i in range(self.num_lstm_layers):
lstm_cell = nn.LSTMCell(
features=self.lstm_hidden_size, name=f'lstm_cell_{i}')
def step_fn(carry, xt):
new_carry, yt = lstm_cell(carry, xt)
return new_carry, yt
# Use lax.scan to process the sequence
carry[i], outputs = jax.lax.scan(step_fn, carry[i], x)
x = outputs # Update x for the next layer
# Aggregate outputs
if self.mean_aggregation:
x = jnp.mean(x, axis=1) # Average over the sequence
else:
x = x[:, -1, :] # Use the last output
# Pass through linear layers
for size in self.linear_layer_sizes:
x = nn.Dense(features=size)(x)
x = nn.elu(x)
# Final output layer
x = nn.Dense(features=1)(x)
return x
# Model hyperparameters
lstm_hidden_size = 64
num_lstm_layers = 2
linear_layer_sizes = [32, 16]
mean_aggregation = False
# Initialize model
model = LSTMModel(
lstm_hidden_size=lstm_hidden_size,
num_lstm_layers=num_lstm_layers,
linear_layer_sizes=linear_layer_sizes,
mean_aggregation=mean_aggregation
)
# Dummy input: batch of sequences with 10 timesteps
key = jax.random.PRNGKey(0)
# [batch_size, sequence_length, feature_size]
dummy_input = jax.random.normal(key, (32, 10, 1))
# Initialize carry states
carry = model.initialize_carry(
batch_size=dummy_input.shape[0], feature_size=dummy_input.shape[-1])
# Initialize parameters
params = model.init(key, dummy_input, carry)
# Apply the model
outputs = model.apply(params, dummy_input, carry)
# Should print: [batch_size, 1]
print("Model output shape:", outputs.shape)
Consider using nn.RNN to simplify your code:
lstm = nn.RNN(
nn.LSTMCell(features=self.lstm_hidden_size),
name=f'lstm_cell_{i}'
)
outputs = lstm(x)
RNN will handle the carries for you. If you really want to handle the carries yourself you could use return_carry
and initial_carry
:
lstm = nn.RNN(
nn.LSTMCell(features=self.lstm_hidden_size),
return_carry=True,
name=f'lstm_cell_{i}'
)
carry[i], outputs = lstm(x, initial_carry=carry[i])