Can someone explain the following behaviour? Is it a bug?
from jax import grad
import jax.numpy as jnp
x = jnp.ones(2)
grad(lambda v: jnp.linalg.norm(v-v))(x) # returns DeviceArray([nan, nan], dtype=float32)
grad(lambda v: jnp.linalg.norm(0))(x) # returns DeviceArray([0., 0.], dtype=float32)
I've tried looking up the error online but didn't find anything relevant.
I also skimmed through https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
When you compute grad(lambda v: jnp.linalg.norm(v-v))(x), your function looks roughly like this:
f(x) = sqrt[(x - x)^2]
so, applying the chain rule, the derivative is
df/dx = (x - x) / sqrt[(x - x)^2]
Plugging in any finite x, this evaluates to 0 / sqrt(0) = 0 / 0, which is undefined and is represented by NaN in floating-point arithmetic.
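To see where the NaN enters, here's a minimal sketch (not part of the original question; it assumes JAX's default float32) that evaluates the norm's gradient rule, v / ||v||, directly at the zero vector. Note that the zero inner derivative d(v - v)/dv doesn't rescue things, because NaN times zero is still NaN in floating point:

from jax import grad
import jax.numpy as jnp

x = jnp.ones(2)
grad(jnp.linalg.norm)(jnp.zeros(2))        # returns [nan, nan]: the 0 / 0 case described above
jnp.nan * 0.0                              # returns nan: multiplying by the zero inner derivative doesn't help
grad(lambda v: jnp.linalg.norm(v - v))(x)  # returns [nan, nan]: the NaN propagates to the final gradient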
When you compute grad(lambda v: jnp.linalg.norm(0))(x), your function looks roughly like this:
g(x) = sqrt[0.0^2]
and because it has no dependence on x, the derivative is simply
dg/dx = 0.0
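If you need a finite gradient of the norm at zero, one common workaround (not something your original code does) is the "double where" trick: mask the zero case on both the forward and backward pass so the 0 / 0 never materializes. safe_norm below is a hypothetical helper name, just for illustration:

from jax import grad
import jax.numpy as jnp

def safe_norm(v):
    # Swap in a dummy nonzero vector where v is all zeros, so the norm's
    # gradient stays finite, then select the correct forward value (0.0).
    is_zero = jnp.allclose(v, 0.0)
    v_safe = jnp.where(is_zero, jnp.ones_like(v), v)
    return jnp.where(is_zero, 0.0, jnp.linalg.norm(v_safe))

x = jnp.ones(2)
grad(lambda v: safe_norm(v - v))(x)  # returns [0., 0.] instead of [nan, nan]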
Does that answer your question?