python numpy jax autograd

JAX jax.grad on simple function that takes an array: `ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected`


I'm trying to implement this function and use JAX to automatically build the gradient function:

$f(x) = \sum\limits_{k=1}^{n-1} [100 (x_{k+1} - x_k^2)^2 + (1 - x_k)^2]$

(Sorry, I don't know how to format math on Stack Overflow. Some sister sites allow TeX, but apparently this site does not?)
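For n = 2 the sum has a single term (k = 1), so the function and its gradient are easy to work out by hand. These values are a handy check for whatever jax.grad returns at the starting point x = (-1.2, 1) used in the code:

$$f(x_1, x_2) = 100\,(x_2 - x_1^2)^2 + (1 - x_1)^2$$

$$\frac{\partial f}{\partial x_1} = -400\,x_1\,(x_2 - x_1^2) - 2\,(1 - x_1), \qquad \frac{\partial f}{\partial x_2} = 200\,(x_2 - x_1^2)$$

With $x_1 = -1.2$ and $x_2 = 1$ (so $x_2 - x_1^2 = -0.44$):

$$f = 100\,(0.1936) + (2.2)^2 = 24.2, \qquad \nabla f = \big(-400(-1.2)(-0.44) - 2(2.2),\; 200(-0.44)\big) = (-215.6,\; -88)$$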

```python
import jax
import jax.numpy as jnp

# x is an array, which does not handle type hints well.
def rosenbrock(n: int, x: any) -> float:
    f = 0
    # i is 1-indexed to match document.
    for i in range(1, n):
        # adjust 1-based indices to 0-based python indices.
        xi = x[i-1].item()
        xip1 = x[i].item()

        fi = 100 * (xip1 - xi**2)**2 + (1 - xi)**2
        f = f + fi
    return f


# with n=2.
def rosenbrock2(x: any) -> float:
    return rosenbrock(2, x)


grad_rosenbrock2 = jax.grad(rosenbrock2)

x = jnp.array([-1.2, 1], dtype=jnp.float32).reshape(2,1)

# this line fails with the error given below
grad_rosenbrock2(x)
```

This last line results in:

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[1].
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
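The failure can be reproduced in isolation. The sketch below assumes a reasonably recent JAX release, and its function names are hypothetical.

- It uses jax.jit because, under jit, a traced value never carries a concrete value, so `.item()` raises a ConcretizationTypeError there as well.
- `.reshape(())` on a size-1 array stays traceable under both jax.jit and jax.grad.

```python
import jax
import jax.numpy as jnp


def square_via_item(x):
    # .item() requests a concrete Python float; an abstract tracer cannot supply one.
    return x[0].item() ** 2


def square_via_reshape(x):
    # .reshape(()) turns the shape-(1,) array into a 0-d array that is still a tracer.
    return x.reshape(()) ** 2


x = jnp.array([3.0], dtype=jnp.float32)  # shape (1,), like x[i] in the question

try:
    jax.jit(square_via_item)(x)
    raised = None
except jax.errors.ConcretizationTypeError as err:
    raised = err

print(type(raised).__name__)             # ConcretizationTypeError
print(jax.jit(square_via_reshape)(x))    # 9.0
print(jax.grad(square_via_reshape)(x))   # [6.]
```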

I'm trying to follow the docs, but I'm confused. This is my first time using JAX or Autograd; can someone help me resolve this? Thanks!


Solution

  • The problem is that the .item() method tries to convert an array into a static Python scalar. Inside your grad transformation, x is a traced array (an abstract stand-in JAX uses to record the computation), so converting it to a static value is not possible.

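    Because x has shape (2, 1), each x[i] is a size-1 array of shape (1,). What you need here is to convert that size-1 array into a 0-dimensional (scalar) array, which remains traceable. You can do this with .reshape(()):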

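    import jax

    def rosenbrock(n: int, x: jax.Array) -> jax.Array:
        f = 0
        # i is 1-indexed to match the formula.
        for i in range(1, n):
            # Adjust 1-based indices to 0-based Python indices.
            # x[i] has shape (1,); reshape(()) makes it a 0-d array that stays a
            # tracer, unlike .item(), which needs a concrete Python value.
            xi = x[i-1].reshape(())
            xip1 = x[i].reshape(())

            fi = 100 * (xip1 - xi**2)**2 + (1 - xi)**2
            f = f + fi
        return f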
    

    For more background on JAX transformations and traced arrays, I'd recommend the "How to think in JAX" guide in the JAX documentation.
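Each term only involves the adjacent pair (x_k, x_{k+1}), so the loop can also be replaced by array slicing, and no per-element scalar extraction is needed at all. This is a minimal sketch, not part of the original answer, with these assumptions:

- rosenbrock_vec is a hypothetical name.
- n is taken from the length of x instead of being passed in.
- The input is flattened, so the (2, 1)-shaped x from the question works unchanged.

```python
import jax
import jax.numpy as jnp


def rosenbrock_vec(x: jax.Array) -> jax.Array:
    # Flatten so both shape (n,) and shape (n, 1) inputs are accepted.
    x = x.ravel()
    xk = x[:-1]   # x_1 .. x_{n-1}
    xk1 = x[1:]   # x_2 .. x_n
    # jnp.sum returns a 0-d array, which is the scalar output jax.grad requires.
    return jnp.sum(100.0 * (xk1 - xk**2) ** 2 + (1.0 - xk) ** 2)


grad_rosenbrock_vec = jax.grad(rosenbrock_vec)

x = jnp.array([-1.2, 1.0], dtype=jnp.float32).reshape(2, 1)
g = grad_rosenbrock_vec(x)   # gradient has the same shape as x: (2, 1)

print(rosenbrock_vec(x))     # approximately 24.2
print(g.ravel())             # approximately [-215.6, -88.0]
```

The gradient comes back with the same shape as x. Because nothing in the function needs a concrete value, jax.jit(jax.grad(rosenbrock_vec)) works as well.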