Due to the novelty of Flax, NNX, and JAX, there aren't many resources available. I'm running into the following peculiarity:
import jax
import jax.numpy as jnp
from flax import nnx

KEY = jax.random.key(0)
x = jax.random.normal(KEY, (1, 512))
layer = nnx.Linear(512, 512, rngs=nnx.Rngs(KEY))
y1 = layer(x)
y2 = layer.kernel @ x.squeeze() + layer.bias
print(y1 == y2)  # returns all False
My understanding is that a linear / fully connected layer should be equivalent to a matrix multiplication plus a bias. The discrepancy demonstrated here makes it hard to inspect the layer's behavior (and to implement invertible dense layers using jnp.tensorsolve).
Does anyone know what causes this discrepancy?
The matmul is transposed: nnx.Linear stores its kernel with shape (in_features, out_features) and computes x @ kernel + bias, so the manual version has to multiply on the right (or transpose the kernel). Also, floating-point equality should be checked with an approximate comparison rather than ==, because different ways of computing the same result can accumulate different rounding errors:
import jax
from flax import nnx
KEY = jax.random.key(0)
x = jax.random.normal(KEY, (1,512))
layer = nnx.Linear(512, 512, rngs=nnx.Rngs(KEY))
y1 = layer(x)
y2 = x @ layer.kernel + layer.bias  # kernel has shape (in_features, out_features)
print(jax.numpy.allclose(y1, y2)) # True
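On the invertibility point: because the forward pass is y = x @ kernel + bias with a 2-D kernel, the input can be recovered with an ordinary linear solve (no tensorsolve needed). A minimal sketch reusing the variables above; x_rec is a name I'm introducing here, and it assumes the kernel is square and reasonably well-conditioned:

kernel = layer.kernel.value  # shape (512, 512)
bias = layer.bias.value      # shape (512,)
# y = x @ kernel + bias  =>  kernel.T @ x.T = (y - bias).T
x_rec = jax.numpy.linalg.solve(kernel.T, (y2 - bias).T).T
print(jax.numpy.allclose(x_rec, x, atol=1e-3))  # expected True, up to float32 solve error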