
How to get value of jaxlib.xla_extension.ArrayImpl


Using type(z1[0]), I get jaxlib.xla_extension.ArrayImpl. Printing z1[0] gives Array(0.71530414, dtype=float32). How can I get the actual number 0.71530414?

I tried z1[0][0], since z1[0] seems to be an array holding a single value, but it raises an error: IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.

I also tried a different approach: I searched the web for a way to convert a jax.numpy array to a Python list, but I didn't find an answer.

Can someone help me to get the value inside a jaxlib.xla_extension.ArrayImpl object?


Solution

  • You can use float(x[0]) to convert x[0] to a Python float:

    In [1]: import jax.numpy as jnp
    
    In [2]: x = jnp.array([0.71530414])
    
    In [3]: x
    Out[3]: Array([0.71530414], dtype=float32)
    
    In [4]: x[0]
    Out[4]: Array(0.71530414, dtype=float32)
    
    In [5]: float(x[0])
    Out[5]: 0.7153041362762451
    

    If you're interested in converting the entire JAX array to a list of Python floats, you can use the tolist() method:

    In [6]: x.tolist()
    Out[6]: [0.7153041362762451]
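The sketch below also explains the IndexError from the question. Indexing a 1-D array with a single integer removes that axis, so z1[0] is a 0-dimensional array with shape (). It has no axis left to index, which is why z1[0][0] fails. Besides float(), the array's .item() method and NumPy's np.asarray() also work. This is a minimal sketch that assumes z1 is a 1-D float32 array like the one in the question. The values below are invented for illustration.

```python
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the question's z1 (values made up for illustration).
z1 = jnp.array([0.71530414, 0.25])

elem = z1[0]
# Indexing a 1-D array with one integer drops that axis, leaving a 0-d array.
print(elem.shape)  # ()
print(elem.ndim)   # 0

# A 0-d array has no axis to index, so this reproduces the question's IndexError.
try:
    elem[0]
except IndexError as err:
    print("IndexError:", err)

# Three ways to turn the 0-d array into a plain Python float.
a = float(elem)              # Python's float() on a single-element array
b = elem.item()              # the array's item() method
c = np.asarray(elem).item()  # convert to NumPy first, then take the scalar
print(type(a), type(b), type(c))

# For the whole array: tolist() yields Python floats, np.asarray() a NumPy array.
as_list = z1.tolist()
as_numpy = np.asarray(z1)
print(as_list, as_numpy.dtype)
```

float() and .item() only work on arrays that contain exactly one element. Use tolist() or np.asarray() when you need every value.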