For this vector-valued function I want to evaluate the Jacobian:
```python
import jax
import jax.numpy as jnp

def myf(arr, phi_0, phi_1, phi_2, lambda_0, R):
    # All angles are passed in degrees and converted to radians here
    arr = jnp.deg2rad(arr)
    phi_0 = jnp.deg2rad(phi_0)
    phi_1 = jnp.deg2rad(phi_1)
    phi_2 = jnp.deg2rad(phi_2)
    lambda_0 = jnp.deg2rad(lambda_0)
    n = jnp.sin(phi_1)
    F = 2.0
    rho_0 = 1.0
    rho = R*F*(1/jnp.tan(jnp.pi/4 + arr[1]/2))**n
    x_L = rho*jnp.sin(n*(arr[0] - lambda_0))
    y_L = rho_0 - rho*jnp.cos(n*(arr[0] - lambda_0))
    return jnp.array([x_L, y_L])

arr = jnp.array([-18.1, 29.9])  # arr[0], arr[1] given in degrees
print(jax.jacobian(myf)(arr, 29.5, 29.5, 29.5, -17.0, R=1))
```
I obtain:

```
[[ 0.01312758  0.00014317]
 [-0.00012411  0.01514319]]
```
I'm shocked by these values. Take, for instance, element [0][0], 0.01312758. It is the partial derivative of x_L with respect to arr[0]. Whether I compute it by hand or with sympy, that derivative is ~0.75:
```python
from sympy import symbols, tan, sin, Derivative

x, y = symbols('x y')
x_L = (2.0*(1/tan(3.141592/4 + y/2))**0.492)*sin(0.492*(x + 0.2967))
deriv = Derivative(x_L, x)
print(deriv.doit().evalf(subs={x: -0.3159, y: 0.52}))
```

```
0.752473089673695
```
(Here x and y are arr[0] and arr[1], already converted to radians.) This is also the result I obtain by hand. What is happening with the JAX results? I can't see what I'm doing wrong.
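For reference, the hand computation for element [0][0] only needs the derivative of x_L with respect to the first input, because ρ depends only on the second one. Writing λ for arr[0] and φ for arr[1] (shorthand introduced here, all angles in radians, R = 1, F = 2):

$$
\rho = R F \cot^{n}\!\left(\frac{\pi}{4} + \frac{\varphi}{2}\right), \qquad
x_L = \rho \sin\bigl(n(\lambda - \lambda_0)\bigr), \qquad
\frac{\partial x_L}{\partial \lambda} = n \rho \cos\bigl(n(\lambda - \lambda_0)\bigr)
$$

With n = sin(29.5°) ≈ 0.492, ρ ≈ 1.53 and n(λ − λ0) close to zero, this gives roughly nρ ≈ 0.75.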
Your JAX function takes its inputs in degrees, so its derivatives have units of 1/degree, while your sympy snippet takes its inputs in radians, so its derivative has units of 1/radian. If you convert the JAX outputs to 1/radian (i.e. multiply them by 180 / pi), you'll get the result you're looking for:
```python
result = jax.jacobian(myf)(arr, 29.5, 29.5, 29.5, -17.0, R=1)
print(result * 180 / jnp.pi)
```

```
[[ 0.7521549   0.00820279]
 [-0.00711098  0.8676407 ]]
```
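The factor comes from the chain rule: with θ_rad = (π/180)·θ_deg, ∂f/∂θ_deg = (π/180)·∂f/∂θ_rad, so multiplying by 180/π recovers the per-radian derivative (e.g. 0.01312758 × 57.29578 ≈ 0.752155). A minimal sketch of the same effect on a one-line function; sin_deg is a hypothetical toy that follows the same "convert degrees inside the function" pattern as myf:

```python
import jax
import jax.numpy as jnp

def sin_deg(theta_deg):
    # Hypothetical toy: input in degrees, converted to radians internally
    return jnp.sin(jnp.deg2rad(theta_deg))

theta_deg = 30.0
g_deg = jax.grad(sin_deg)(theta_deg)  # derivative per degree: cos(30°) * pi / 180
g_rad = g_deg * 180 / jnp.pi          # rescaled to derivative per radian: cos(30°)
print(g_deg, g_rad, jnp.cos(jnp.deg2rad(theta_deg)))
```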
Alternatively, you could rewrite myf to accept its inputs in radians and get the expected result by taking its Jacobian directly.
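A minimal sketch of that rewrite, assuming every angle argument (arr, phi_0, phi_1, phi_2, lambda_0) is passed in radians; myf_rad is a hypothetical name, and its body is myf with the deg2rad calls moved out to the caller:

```python
import jax
import jax.numpy as jnp

def myf_rad(arr, phi_0, phi_1, phi_2, lambda_0, R):
    # Same map as myf, but all angles are assumed to be in radians already
    n = jnp.sin(phi_1)
    F = 2.0
    rho_0 = 1.0
    rho = R * F * (1 / jnp.tan(jnp.pi / 4 + arr[1] / 2)) ** n
    x_L = rho * jnp.sin(n * (arr[0] - lambda_0))
    y_L = rho_0 - rho * jnp.cos(n * (arr[0] - lambda_0))
    return jnp.array([x_L, y_L])

# Convert once, outside the differentiated function
arr_rad = jnp.deg2rad(jnp.array([-18.1, 29.9]))
phi = jnp.deg2rad(29.5)
lambda_0 = jnp.deg2rad(-17.0)

# Jacobian with respect to arr (argnums=0 by default), in units of 1/radian;
# it should agree, up to float32 rounding, with the rescaled result above
J = jax.jacobian(myf_rad)(arr_rad, phi, phi, phi, lambda_0, R=1.0)
print(J)
```

Keeping the unit conversion outside the differentiated function means the Jacobian is directly in the units the rest of the math uses, so no rescaling step is needed.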