I have a question about floating-point precision in JAX. Consider the following code:
import numpy as np
import jax.numpy as jnp
print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
print('np.arctan(10) is:','%.60f' % np.arctan(10))
Output:
jnp.arctan(10) is: 1.471127629280090332031250000000000000000000000000000000000000
np.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000
print('jnp.arctan(10+1e-7) is:','%.60f' % jnp.arctan(10+1e-7))
print('np.arctan(10+1e-7) is:','%.60f' % np.arctan(10+1e-7))
Output:
jnp.arctan(10+1e-7) is: 1.471127629280090332031250000000000000000000000000000000000000
np.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
jnp returned identical results for arctan(x) when the input changed by a small amount (1e-7), but np did not. How can I make jax.numpy give the correct result for such a small change in x?
Any comments are appreciated.
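JAX defaults to float32 computation, which has a relative precision of about 1E-7. Near 10, adjacent float32 values are roughly 9.5E-7 apart, so adding 1E-7 rounds back to exactly 10. This means that your two inputs are effectively identical: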
>>> np.float32(10) == np.float32(10 + 1E-7)
True
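To see why the increment disappears, the following minimal NumPy sketch compares the gap between neighbouring representable numbers at 10 in float32 and float64. It uses only `np.finfo` and `np.spacing`; the variable names are illustrative.

```python
import numpy as np

# Relative precision (machine epsilon) of float32: 2**-23, about 1.19e-07.
eps32 = np.finfo(np.float32).eps
# Absolute gap between 10 and the next larger float32: 2**-20, about 9.54e-07.
gap = np.spacing(np.float32(10))
print(eps32, gap)

# 1e-7 is smaller than half that gap, so 10 + 1e-7 rounds back to exactly 10.
print(np.float32(10 + 1e-7) == np.float32(10))  # True

# In float64 the gap near 10 is 2**-49 (about 1.8e-15), so the increment survives.
print(np.float64(10 + 1e-7) == np.float64(10))  # False
print(np.spacing(np.float64(10)))
```

In other words, no float32 implementation of arctan can distinguish the two inputs, because they are already the same number before arctan is called.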
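If you want 64-bit precision like NumPy, you can enable it as described in the "Double (64bit) precision" section of JAX's Sharp Bits documentation (the flag must be set at startup, before any JAX arrays are created). The results will then match NumPy to 64-bit precision: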
import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp
import numpy as np
print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
print('np.arctan(10) is: ','%.60f' % np.arctan(10))
print('jnp.arctan(10+1e-7) is:','%.60f' % jnp.arctan(10+1e-7))
print('np.arctan(10+1e-7) is: ','%.60f' % np.arctan(10+1e-7))
Output:
jnp.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000
np.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000
jnp.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
np.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
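As a quick sanity check, here is a sketch that assumes a JAX version where `jax.config.update("jax_enable_x64", True)` is honoured when called at startup. It confirms that results are float64 and that the change in output matches the expected slope: for a step h, arctan(10 + h) - arctan(10) should be close to h * arctan'(10) = h / (1 + 10^2), i.e. about 9.9e-10 for h = 1e-7. The derivative is obtained with `jax.grad`; the names `h`, `diff`, and `slope` are illustrative.

```python
import jax

# Enable 64-bit mode before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

h = 1e-7
y0 = jnp.arctan(10.0)
y1 = jnp.arctan(10.0 + h)

# With x64 enabled, Python floats are promoted to float64 arrays.
print(y0.dtype)  # float64

# Finite difference vs. analytic slope: arctan'(10) = 1 / (1 + 10**2) = 1/101.
diff = float(y1 - y0)
slope = float(jax.grad(jnp.arctan)(10.0))
print(diff, h * slope)

# Agreement with NumPy to within a few float64 ulps
# (not guaranteed to be bit-identical across backends).
print(abs(float(y0) - np.arctan(10.0)))
```

Comparing against a tolerance rather than testing for exact equality is deliberate: JAX (via XLA) and NumPy may use different arctan implementations, which can differ in the last bit even in 64-bit mode.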
(Note, however, that even the 64-bit precision used by Python and NumPy is only accurate to about one part in 10^16. '%.60f' prints the exact decimal expansion of the stored binary value, so only the first 16 or so significant digits you printed agree with the true value of arctan; the rest are not meaningful.)
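A minimal NumPy sketch of how many digits are actually meaningful: `np.finfo` reports the guaranteed decimal precision of float64, and 17 significant digits are always enough to reproduce a float64 exactly, so printing more adds no information.

```python
import numpy as np

y = np.arctan(10.0)  # float64 result

# float64 guarantees about 15 decimal digits (machine epsilon 2**-52, about 2.2e-16).
print(np.finfo(np.float64).precision)  # 15
print(np.finfo(np.float64).eps)

# 17 significant digits always suffice to round-trip a float64 exactly.
s17 = '%.17g' % y
print(s17)
print(float(s17) == y)  # True

# '%.60f' merely spells out the exact decimal value of the stored binary number;
# the trailing digits say nothing about the true value of arctan(10).
print('%.60f' % y)
```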