
How to use Flax LSTM layers in 2023


I am wondering whether anyone here knows how to get Flax LSTM layers to work in 2023. I have tried some of the code snippets from the official Flax documentation, such as:

https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.scan.html

In particular, the first example provided there,

import flax.linen as nn
import jax
import jax.numpy as jnp

class LSTM(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
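    # Lift LSTMCell over axis 1 (time); "params" are broadcast, so all steps share one set of weights
    ScanLSTM = nn.scan(
      nn.LSTMCell, variable_broadcast="params",
      split_rngs={"params": False}, in_axes=1, out_axes=1)

    lstm = ScanLSTM(self.features)
    # The carry is shaped from a single time step: (batch, input_features)
    input_shape = x[:, 0].shape
    carry = lstm.initialize_carry(jax.random.key(0), input_shape)
    carry, x = lstm(carry, x)
    return x

x = jnp.ones((4, 12, 7))  # (batch, time, features)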
module = LSTM(features=32)
y, variables = module.init_with_output(jax.random.key(0), x)

throws an error. I have looked for other examples, but the API seems to have changed at some point in 2023, so the examples I found online no longer work.

In short, I am looking for a simple example of how to pass a time series through an LSTM in Flax.

Thank you for your help.


Solution

  • The snippet you provided runs correctly with the most recent version of Flax (version 0.7.4). If you are on older versions, note that jax.random.key is a JAX function, not a Flax one: if your JAX installation predates it, change jax.random.key to jax.random.PRNGKey. For background on this JAX PRNG key change, see JEP 9263: Typed Keys and Pluggable PRNGs.