Tags: sympy, jax, automatic-differentiation

Derivatives from JAX and SymPy do not coincide


I want to evaluate the Jacobian of this vector-valued function:

import jax
import jax.numpy as jnp

def myf(arr, phi_0, phi_1, phi_2, lambda_0, R):
    # Convert every angle argument from degrees to radians
    arr = jnp.deg2rad(arr)
    phi_0 = jnp.deg2rad(phi_0)
    phi_1 = jnp.deg2rad(phi_1)
    phi_2 = jnp.deg2rad(phi_2)
    lambda_0 = jnp.deg2rad(lambda_0)
    
    n = jnp.sin(phi_1)
    
    F = 2.0
    rho_0 = 1.0
    rho = R*F*(1/jnp.tan(jnp.pi/4 + arr[1]/2))**n
    x_L = rho*jnp.sin(n*(arr[0] - lambda_0))
    y_L = rho_0 - rho*jnp.cos(n*(arr[0] - lambda_0))
    
    return jnp.array([x_L,y_L])


arr = jnp.array([-18.1, 29.9])

# Jacobian with respect to the first argument, arr (given in degrees)
print(jax.jacobian(myf)(arr, 29.5, 29.5, 29.5, -17.0, R=1))

I obtain:

[[ 0.01312758  0.00014317]
 [-0.00012411  0.01514319]]

These values surprise me. Take, for instance, element [0][0], 0.01312758: it is the partial derivative of x_L with respect to arr[0]. Computed by hand or with SymPy, that derivative is about 0.75.

from sympy import *

x, y = symbols('x y')

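# Constants hard-coded in radians: 0.492 ≈ sin(29.5°), 0.2967 ≈ 17° (so x + 0.2967 = x - lambda_0); R = 1, F = 2
x_L = (2.0*(1/tan(3.141592/4 + y/2))**0.492)*sin(0.492*(x + 0.2967))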
deriv = Derivative(x_L, x)
deriv.doit()
deriv.doit().evalf(subs={x: -0.3159, y: 0.52})

0.752473089673695

(Here x and y are arr[0] and arr[1], already converted to radians.) This matches the result I get by hand. What is going on with the JAX results? I can't see what I'm doing wrong.
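For reference, the hand derivative follows from the definitions in myf with all angles in radians (x = arr[0], y = arr[1]). Since ρ does not depend on x, only the sine factor is differentiated:

```latex
\begin{aligned}
n &= \sin\varphi_1, \qquad
\rho = R\,F\left[\cot\!\left(\tfrac{\pi}{4} + \tfrac{y}{2}\right)\right]^{n}, \qquad
x_L = \rho\,\sin\big(n(x - \lambda_0)\big) \\
\frac{\partial x_L}{\partial x} &= n\,\rho\,\cos\big(n(x - \lambda_0)\big)
\end{aligned}
```

With R = 1, F = 2, φ1 = 29.5°, λ0 = -17°, x = -18.1° and y = 29.9° (all converted to radians), this gives n ≈ 0.4924, ρ ≈ 1.5275 and ∂x_L/∂x ≈ 0.752 per radian.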


Solution

  • Your JAX function takes its angles in degrees, so its Jacobian has units of 1/degree, while your SymPy expression takes radians, so its derivative has units of 1/radian. By the chain rule, x_rad = x_deg·π/180 gives ∂f/∂x_deg = (π/180)·∂f/∂x_rad. If you convert the JAX outputs to 1/radian (i.e. multiply them by 180/π), you get the result you're looking for:

    result = jax.jacobian(myf)(arr, 29.5, 29.5, 29.5, -17.0, R=1)
    print(result * 180 / jnp.pi)
    
    [[ 0.7521549   0.00820279]
     [-0.00711098  0.8676407 ]]
    

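    The small remaining difference from the SymPy value (0.75215 vs 0.75247) comes from the rounded constants in the SymPy snippet, e.g. 0.492 for sin(29.5°) ≈ 0.4924 and 0.52 for 29.9° ≈ 0.5219 rad.

    Alternatively, you could rewrite myf to accept its inputs in radians and take its Jacobian directly; the result is then already in units of 1/radian.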
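A minimal sketch of that radians-based alternative, assuming you are free to change the function's convention. myf_rad is a hypothetical name: its body is the question's myf without the deg2rad calls, and the degree version is rebuilt as a thin wrapper around it so the two Jacobians can be compared side by side:

```python
import jax
import jax.numpy as jnp


def myf_rad(arr, phi_0, phi_1, phi_2, lambda_0, R):
    # Hypothetical radians-based variant of the question's myf:
    # every angle argument is expected in radians, so no conversion happens here.
    n = jnp.sin(phi_1)

    F = 2.0
    rho_0 = 1.0
    rho = R * F * (1 / jnp.tan(jnp.pi / 4 + arr[1] / 2)) ** n
    x_L = rho * jnp.sin(n * (arr[0] - lambda_0))
    y_L = rho_0 - rho * jnp.cos(n * (arr[0] - lambda_0))

    return jnp.array([x_L, y_L])


def myf(arr, phi_0, phi_1, phi_2, lambda_0, R):
    # Degree-based version, mathematically the same as the question's myf:
    # convert all angles to radians, then delegate to myf_rad.
    return myf_rad(jnp.deg2rad(arr), jnp.deg2rad(phi_0), jnp.deg2rad(phi_1),
                   jnp.deg2rad(phi_2), jnp.deg2rad(lambda_0), R)


arr_deg = jnp.array([-18.1, 29.9])
arr_rad = jnp.deg2rad(arr_deg)
phi = jnp.deg2rad(29.5)
lam = jnp.deg2rad(-17.0)

# Jacobian w.r.t. arr given in degrees: entries are per degree
J_deg = jax.jacobian(myf)(arr_deg, 29.5, 29.5, 29.5, -17.0, R=1)
# Jacobian w.r.t. arr given in radians: entries are per radian
J_rad = jax.jacobian(myf_rad)(arr_rad, phi, phi, phi, lam, R=1)

print(J_rad)  # J_rad[0, 0] is about 0.752, matching the hand/SymPy derivative
print(jnp.allclose(J_deg * 180 / jnp.pi, J_rad, rtol=1e-4))
```

Because myf only rescales its inputs before calling myf_rad, the two Jacobians differ by the factor π/180 up to float32 rounding, which is what the allclose check confirms.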