Tags: python, jax, equinox, computation-graph

TypeError: unhashable type: 'ArrayImpl' when trying to use Equinox module with jax.lax.scan


I'm new to Equinox and JAX, but I wanted to use them to simulate a dynamical system.

However, when I pass my system model (an Equinox module) as the function argument to jax.lax.scan, I get the unhashable type error in the title. I understand that JAX expects the scanned function to be pure, but I thought an Equinox Module would behave like one.

Here is a test script that reproduces the error:

import equinox as eqx
import jax
import jax.numpy as jnp


class EqxModel(eqx.Module):
    A: jax.Array
    B: jax.Array
    C: jax.Array
    D: jax.Array

    def __call__(self, states, inputs):
        x = states.reshape(-1, 1)
        u = inputs.reshape(-1, 1)
        x_next = self.A @ x + self.B @ u
        y = self.C @ x + self.D @ u
        return x_next.reshape(-1), y.reshape(-1)


def simulate(model, inputs, x0):
    xk = x0
    outputs = []
    for uk in inputs:
        xk, yk = model(xk, uk)
        outputs.append(yk)
    outputs = jnp.stack(outputs)
    return xk, outputs


A = jnp.array([[0.7, 1.0], [0.0, 1.0]])
B = jnp.array([[0.0], [1.0]])
C = jnp.array([[0.3, 0.0]])
D = jnp.array([[0.0]])
model = EqxModel(A, B, C, D)

# Test simulation
inputs = jnp.array([[0.0], [1.0], [1.0], [1.0]])
x0 = jnp.zeros(2)
xk, outputs = simulate(model, inputs, x0)
assert jnp.allclose(xk, jnp.array([2.7, 3.0]))
assert jnp.allclose(outputs, jnp.array([[0.0], [0.0], [0.0], [0.3]]))

# This raises TypeError: unhashable type: 'ArrayImpl'
xk, outputs = jax.lax.scan(model, x0, inputs)

What does unhashable type: 'ArrayImpl' refer to? Is it the arrays A, B, C, and D? In this model, these matrices are parameters and should therefore stay constant for the duration of the simulation.

I just found this issue thread that might be related:


Solution

  • Owen Lockwood (lockwo) provided an explanation and answer in this issue thread, which I reiterate below.

    I believe your issue is happening because JAX tries to hash the function you are scanning over, but it can't hash the arrays that are in the module. There are probably a number of things that you could do to solve this, the simplest being to just curry the model, e.g. xk, outputs = jax.lax.scan(lambda carry, y: model(carry, y), x0, inputs) works fine.

    Or, rewritten in terms of the variable names I am using:

    xk, outputs = jax.lax.scan(lambda xk, uk: model(xk, uk), x0, inputs)