matrix-multiplication · jax · flax

Why is Flax Linear layer not identical to matrix multiplication?


Because Flax, NNX, and JAX are relatively new, there aren't many resources available. I'm running into the following peculiarity:

import jax
from flax import nnx

KEY = jax.random.key(0)
x = jax.random.normal(KEY, (1, 512))
layer = nnx.Linear(512, 512, rngs=nnx.Rngs(KEY))
y1 = layer(x)
y2 = layer.kernel @ x.squeeze() + layer.bias
print(y1 == y2)  # returns all False

My understanding is that a linear / fully connected layer is just a matrix multiplication plus a bias, so the two results should match. The discrepancy shown here hinders inspecting the layer's behavior (and implementing invertible dense layers using jnp.tensorsolve).

Does anyone know what causes this discrepancy?


Solution

  • Two things are off. First, the matmul is transposed: Flax's Linear stores its kernel with shape (in_features, out_features) and computes y = x @ kernel + bias, so the input must multiply the kernel from the left. Second, floating-point results should be compared approximately rather than exactly, because different ways of computing the same value can accumulate different rounding errors:

    import jax
    from flax import nnx
    
    KEY = jax.random.key(0)
    x = jax.random.normal(KEY, (1,512))
    layer = nnx.Linear(512, 512, rngs=nnx.Rngs(KEY))
    y1 = layer(x)
    y2 = x @ layer.kernel + layer.bias
    print(jax.numpy.allclose(y1, y2))  # True
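
  • For reference, the question's original expression also matches once the kernel is transposed. A minimal sketch continuing from the code above (it accesses the raw arrays via .value, which nnx.Param provides):

    # kernel.T @ x.T is the transpose of x @ kernel, so this reproduces y1
    y3 = layer.kernel.value.T @ x.squeeze() + layer.bias.value
    print(jax.numpy.allclose(y1, y3))  # True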
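
  • With the orientation fixed, the invertible dense layer mentioned in the question can be built with jnp.linalg.solve rather than jnp.tensorsolve. A minimal sketch continuing from the code above; it assumes the randomly initialized square kernel is invertible (which holds almost surely), and the loosened atol is a hypothetical tolerance to absorb float32 round-off in a 512×512 solve:

    import jax.numpy as jnp

    # Forward pass is y = x @ kernel + bias, so recover x by solving
    # kernel.T @ x.T = (y - bias).T for x.T.
    x_rec = jnp.linalg.solve(layer.kernel.value.T, (y1 - layer.bias.value).T).T
    print(jnp.allclose(x_rec, x, atol=1e-4))  # True up to rounding error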