
Floating-point precision in JAX


I have a question about the precision of floats in JAX. Consider the following code:

import numpy as np
import jax.numpy as jnp

print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
print('np.arctan(10) is:','%.60f' % np.arctan(10))

jnp.arctan(10) is: 1.471127629280090332031250000000000000000000000000000000000000
np.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000


print('jnp.arctan(10+1e-7) is:','%.60f' % jnp.arctan(10+1e-7))
print('np.arctan(10+1e-7) is:','%.60f' % np.arctan(10+1e-7))

jnp.arctan(10+1e-7) is: 1.471127629280090332031250000000000000000000000000000000000000
np.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000

jnp gives identical results for arctan(x) when the input changes by a small amount (1e-7), but np does not. How can I make jax.numpy produce the correct result for such a small change in x?

Any comments are appreciated.


Solution

  • JAX defaults to float32 computation, which has a relative precision of about 1E-7 (machine epsilon is about 1.2E-7, and near 10 adjacent float32 values are about 9.5E-7 apart). This means that your two inputs are effectively identical:

    >>> np.float32(10) == np.float32(10 + 1E-7)
    True
    

    If you want 64-bit precision like NumPy, you can enable it as described in JAX Sharp Bits: Double (64bit) precision. Set the flag at startup, before any JAX arrays are created. The results will then match NumPy to 64-bit precision:

    import jax
    jax.config.update('jax_enable_x64', True)
    
    import jax.numpy as jnp
    import numpy as np
    
    print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
    print('np.arctan(10) is: ','%.60f' % np.arctan(10))
    
    print('jnp.arctan(10+1e-7) is:','%.60f' % jnp.arctan(10+1e-7))
    print('np.arctan(10+1e-7) is: ','%.60f' % np.arctan(10+1e-7))
    
    jnp.arctan(10) is: 1.471127674303734700345103192375972867012023925781250000000000
    np.arctan(10) is:  1.471127674303734700345103192375972867012023925781250000000000
    jnp.arctan(10+1e-7) is: 1.471127675293833592107262120407540351152420043945312500000000
    np.arctan(10+1e-7) is:  1.471127675293833592107262120407540351152420043945312500000000
    

    (Note, however, that even the 64-bit precision used by Python and NumPy only resolves about one part in 10^16, i.e. roughly 16 significant digits. Everything beyond that in the 60-digit output you printed is the exact decimal expansion of the stored binary value, not digits of the true arctan.)
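To put numbers on the float32 explanation, here is a short NumPy-only sketch (the variable names are illustrative, not from the original answer). It uses `np.spacing` to measure the gap between adjacent float32 values near x = 10 and near arctan(10) ≈ 1.47, and compares them with the first-order change in the output, dx / (1 + x²). It suggests float32 loses the change twice: the 1e-7 input increment rounds away, and even if it did not, the resulting output change (about 1e-9) is roughly 120 times smaller than the float32 spacing near 1.47.

```python
import numpy as np

x = 10.0
dx = 1e-7

# Gap between 10.0 and the next representable value in each format.
ulp32_at_x = float(np.spacing(np.float32(x)))  # 2**-20, about 9.5e-7
ulp64_at_x = float(np.spacing(np.float64(x)))  # 2**-49, about 1.8e-15

# An increment smaller than half the float32 gap rounds back to 10.0.
input_lost_in_f32 = bool(np.float32(x + dx) == np.float32(x))

# First-order change in the output: d/dx arctan(x) = 1 / (1 + x**2) = 1/101 at x = 10.
expected_dy = dx / (1.0 + x**2)  # about 9.9e-10

# Gap between arctan(10) ~ 1.4711 and its float32 neighbour.
ulp32_at_y = float(np.spacing(np.float32(np.arctan(x))))  # 2**-23, about 1.2e-7

# Change actually measured in float64 (what NumPy computes by default).
dy64 = float(np.arctan(x + dx) - np.arctan(x))

print(f"float32 spacing at x = 10:          {ulp32_at_x:.3e}")
print(f"float64 spacing at x = 10:          {ulp64_at_x:.3e}")
print(f"10 + 1e-7 rounds to 10 in float32:  {input_lost_in_f32}")
print(f"expected change in arctan:          {expected_dy:.3e}")
print(f"float64 measured change:            {dy64:.3e}")
print(f"float32 spacing at arctan(10):      {ulp32_at_y:.3e}")
```

The derivative estimate agrees with the difference of the two NumPy values printed in the question (about 9.9e-10), which is why only a 64-bit dtype can represent it.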