I'm trying to implement this function and use JAX to automatically build the gradient function:
$f(x) = \sum\limits_{k=1}^{n-1} [100 (x_{k+1} - x_k^2)^2 + (1 - x_k)^2]$
(Sorry, I don't know how to format math on Stack Overflow. Some sister sites allow TeX, but apparently this site does not?)
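For reference, here is the gradient worked out by hand, which is useful for checking whatever JAX returns. Differentiating $f$ with respect to $x_j$ picks up the $k = j$ term (present only when $j \le n-1$) and the $k = j-1$ term (present only when $j \ge 2$):

$$\frac{\partial f}{\partial x_j} = \underbrace{-400\,x_j\,(x_{j+1} - x_j^2) - 2\,(1 - x_j)}_{\text{only if } j \le n-1} \;+\; \underbrace{200\,(x_j - x_{j-1}^2)}_{\text{only if } j \ge 2}$$

For $n = 2$ at the starting point $x = (-1.2,\ 1)$:

$$f(-1.2,\ 1) = 100\,(1 - 1.44)^2 + (1 + 1.2)^2 = 19.36 + 4.84 = 24.2$$

$$\nabla f(-1.2,\ 1) = \big(-400\,(-1.2)(1 - 1.44) - 2\,(2.2),\ \ 200\,(1 - 1.44)\big) = (-215.6,\ -88)$$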
```python
import jax
import jax.numpy as jnp

# x is an array, which does not handle type hints well.
def rosenbrock(n: int, x: any) -> float:
    f = 0
    # i is 1-indexed to match document.
    for i in range(1, n):
        # adjust 1-based indices to 0-based python indices.
        xi = x[i-1].item()
        xip1 = x[i].item()
        fi = 100 * (xip1 - xi**2)**2 + (1 - xi)**2
        f = f + fi
    return f

# with n=2.
def rosenbrock2(x: any) -> float:
    return rosenbrock(2, x)

grad_rosenbrock2 = jax.grad(rosenbrock2)

x = jnp.array([-1.2, 1], dtype=jnp.float32).reshape(2,1)

# this line fails with the error given below
grad_rosenbrock2(x)
```
This last line results in:

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[1].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
I'm trying to follow the docs, but I'm confused. This is my first time using JAX or Autograd; can someone help me resolve this? Thanks!
The problem is that the `.item()` method attempts to convert an array to a static Python scalar. Inside your `grad` transformation the arrays are tracers (abstract placeholders that carry a shape and dtype but no concrete value), so conversion to a static value is not possible.
What you need here is to convert a size-1 array to a scalar array, which you can do using `.reshape(())`:
```python
def rosenbrock(n: int, x: any) -> float:
    f = 0
    # i is 1-indexed to match document.
    for i in range(1, n):
        # adjust 1-based indices to 0-based python indices.
        xi = x[i-1].reshape(())
        xip1 = x[i].reshape(())
        fi = 100 * (xip1 - xi**2)**2 + (1 - xi)**2
        f = f + fi
    return f
```
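To check the fix end to end, here is a sketch that runs the corrected `rosenbrock` on the question's `(2, 1)`-shaped input. The loop is unchanged, and only the comments and type hints are trimmed. The result is compared against the partial derivatives of the $n = 2$ case worked out by hand. The returned gradient has the same shape as `x`, here `(2, 1)`.

```python
import jax
import jax.numpy as jnp


def rosenbrock(n: int, x):
    f = 0.0
    for i in range(1, n):
        # reshape(()) keeps each element as a traced 0-d array.
        xi = x[i - 1].reshape(())
        xip1 = x[i].reshape(())
        f = f + 100 * (xip1 - xi**2) ** 2 + (1 - xi) ** 2
    return f


def rosenbrock2(x):
    return rosenbrock(2, x)


grad_rosenbrock2 = jax.grad(rosenbrock2)

# Same input as in the question: shape (2, 1).
x = jnp.array([-1.2, 1], dtype=jnp.float32).reshape(2, 1)
g = grad_rosenbrock2(x)

# Hand-derived partials for n = 2:
#   df/dx1 = -400 * x1 * (x2 - x1**2) - 2 * (1 - x1)
#   df/dx2 =  200 * (x2 - x1**2)
x1, x2 = -1.2, 1.0
expected = jnp.array(
    [[-400 * x1 * (x2 - x1**2) - 2 * (1 - x1)],
     [200 * (x2 - x1**2)]],
    dtype=jnp.float32,
)
print(g.ravel())
print(jnp.allclose(g, expected, rtol=1e-4))
```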
For more background on JAX transformations and traced arrays, I'd recommend the "How to think in JAX" tutorial in the JAX documentation.
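As a side note, the Python `for` loop is unrolled at trace time. That is fine for small `n`, but trace and compile time grow with `n`. Below is a sketch of a vectorized alternative that expresses the sum with array slicing, so `n` is simply the length of `x`. It is an illustration, not part of the original answer, and the name `rosenbrock_vec` is hypothetical. Flattening first means both `(n,)` and `(n, 1)` inputs work, and the gradient comes back in the input's shape.

```python
import jax
import jax.numpy as jnp


def rosenbrock_vec(x):
    # Flatten so (n,) and (n, 1) inputs are handled identically.
    x = jnp.ravel(x)
    # With 1-based indices as in the formula: xk holds x_1..x_{n-1},
    # xk1 holds x_2..x_n.
    xk, xk1 = x[:-1], x[1:]
    return jnp.sum(100.0 * (xk1 - xk**2) ** 2 + (1.0 - xk) ** 2)


# No .item() or Python-level indexing, so grad and jit both apply cleanly.
grad_rosenbrock_vec = jax.jit(jax.grad(rosenbrock_vec))

x = jnp.array([-1.2, 1.0])
print(rosenbrock_vec(x))        # about 24.2
print(grad_rosenbrock_vec(x))   # about [-215.6, -88.0]

# The minimum is at x = (1, ..., 1), where f and its gradient are both zero.
ones = jnp.ones(5)
print(rosenbrock_vec(ones), grad_rosenbrock_vec(ones))
```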