Using type(z1[0]) I get jaxlib.xla_extension.ArrayImpl. Printing z1[0] I get Array(0.71530414, dtype=float32). How can I get the actual number 0.71530414?
I tried z1[0][0] because z1[0] is a kind of array with a single value, but it gives me an error: IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.
I also tried a different approach: I searched the web for a way to convert a jax.numpy array to a Python list, but I didn't find an answer.
Can someone help me get the value inside a jaxlib.xla_extension.ArrayImpl object?
You can use float(x[0]) to convert x[0] to a Python float:
In [1]: import jax.numpy as jnp
In [2]: x = jnp.array([0.71530414])
In [3]: x
Out[3]: Array([0.71530414], dtype=float32)
In [4]: x[0]
Out[4]: Array(0.71530414, dtype=float32)
In [5]: float(x[0])
Out[5]: 0.7153041362762451
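The extra digits in Out[5] appear because the array stores a 32-bit float while a Python float is 64-bit, so the same stored value is simply printed at higher precision. As a minimal sketch (the variable names here are illustrative, and it assumes JAX's default float32 setting), the item() method should give the same Python scalar as float(), and checking the shape of x[0] shows why indexing it a second time raises the IndexError from the question:

```python
import jax.numpy as jnp

x = jnp.array([0.71530414])   # 1-D array holding one float32 element
elem = x[0]                   # 0-dimensional JAX array, not a Python number

# A 0-d array has an empty shape, so there is no axis left to index;
# that is why elem[0] fails with "Too many indices for array".
print(elem.shape)
print(elem.ndim)

as_float = float(elem)        # convert to a Python float (64-bit)
via_item = elem.item()        # item() also returns a Python scalar

print(type(as_float), type(via_item))
print(as_float == via_item)
```

Either call works for a single-element array; float() is the more explicit choice when you specifically want a Python float.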
If you're interested in converting the entire JAX array to a list of Python floats, you can use the tolist() method:
In [6]: x.tolist()
Out[6]: [0.7153041362762451]
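For arrays with more than one dimension, tolist() should return nested Python lists that mirror the array's shape. A small sketch, using a made-up 2x2 array m for illustration:

```python
import jax.numpy as jnp

# Hypothetical 2x2 array, used only to illustrate the nesting.
m = jnp.array([[1.0, 2.0], [3.0, 4.0]])

nested = m.tolist()           # list of lists of Python floats
print(nested)
print(type(nested[0][0]))
```

Each innermost element is a plain Python float, so the result can be passed to code that knows nothing about JAX.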