I have a function whose maximum I would like to find by optimizing two of its variables with JAX.
My current code, which does not work, is:
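```python
import jax.numpy as jnp
import jax
import scipy.optimize  # import the submodule explicitly; plain `import scipy` may not expose it
import numpy as np

def temp_func(x, y, z):
    # x has shape (2,) and is broadcast across the rows of the (2, 2) product
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp  # maximizing tmp == minimizing -tmp

def obj_func(xy, z):
    # unpack the flat parameter vector: x = xy[:2], y = xy[2:] as a 2x2 matrix
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(jnp.array(x), jnp.array(y), z))

grad_tmp = jax.grad(obj_func, argnums=0)  # gradient w.r.t. xy, i.e. both x and y

xy = jnp.concatenate([np.random.rand(2), np.random.rand(2 * 2)])
z = jnp.array(np.random.rand(2, 2))
print(obj_func(xy, z))

result = scipy.optimize.minimize(obj_func,
                                 xy,
                                 args=(z,),
                                 method='L-BFGS-B',
                                 jac=grad_tmp
                                 )
```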
With this code, I get the error ``ValueError: failed in converting 7th argument `g' of _lbfgsb.setulb to C/Fortran array``.
Do you have any suggestions to resolve the issue?
You might consider using the JAX version of `scipy.optimize.minimize`, which automatically computes and uses the derivative:
```python
import jax.scipy.optimize
result = jax.scipy.optimize.minimize(obj_func, xy, args=(z,), method='BFGS')
```
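If you would rather keep scipy's `L-BFGS-B`, a likely cause of the error is that `jax.grad` returns a float32 JAX array, while the Fortran routine behind L-BFGS-B expects a float64 NumPy array for the gradient `g`. A minimal sketch under that assumption wraps the gradient in `np.asarray(..., dtype=np.float64)`. The helper names `obj_float` and `jac_float64`, the fixed starting point, and the `[0, 1]` box bounds are illustrative choices, not part of the original code; the bounds are there only because this objective has no finite minimizer without them.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize


def temp_func(x, y, z):
    # x (shape (2,)) is broadcast across the rows of the (2, 2) product
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp


def obj_func(xy, z):
    # Flat parameter vector: x = xy[:2], y = xy[2:] reshaped to 2x2
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(x, y, z))


grad_tmp = jax.grad(obj_func, argnums=0)


def obj_float(xy, z):
    # Hypothetical helper: hand scipy a plain Python float
    return float(obj_func(xy, z))


def jac_float64(xy, z):
    # Hypothetical helper: jax.grad returns a float32 JAX array by default,
    # so convert it to the float64 NumPy array L-BFGS-B works with.
    return np.asarray(grad_tmp(xy, z), dtype=np.float64)


rng = np.random.default_rng(0)
z = jnp.asarray(rng.random((2, 2)))
xy0 = np.full(6, 0.5)  # fixed, illustrative starting point

# Illustrative box constraints: without bounds the objective is unbounded below.
bounds = [(0.0, 1.0)] * 6

result = scipy.optimize.minimize(
    obj_float, xy0, args=(z,), method="L-BFGS-B", jac=jac_float64, bounds=bounds
)
print(result.x)
```

With positive entries in `z`, the objective decreases in every coordinate, so within these bounds the minimizer should sit at the upper bound `1.0`.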
That said, the results in either case are not going to be very meaningful: your objective function decreases without bound, linearly in `x` and (for the positive random entries of `z`) cubically in `y`, so it is minimized only in the limit x, y → ∞.
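Writing the objective out elementwise makes this concrete. Take $x \in \mathbb{R}^2$, $Y, Z \in \mathbb{R}^{2\times 2}$, with $x$ broadcast across the rows of $Y^{\circ 3}\tanh(Z)$ (so entry $(i,j)$ is $x_j + \sum_k Y_{ik}^3 \tanh Z_{kj}$). The shorthand $s_k$ is introduced here only for readability:

$$
f(x, Y) = -\sum_{i=1}^{2}\sum_{j=1}^{2}\Big(x_j + \sum_{k=1}^{2} Y_{ik}^{3}\tanh Z_{kj}\Big)
= -2\sum_{j=1}^{2} x_j - \sum_{i=1}^{2}\sum_{k=1}^{2} Y_{ik}^{3}\, s_k,
\qquad s_k = \sum_{j=1}^{2}\tanh Z_{kj}
$$

$$
\frac{\partial f}{\partial x_j} = -2,
\qquad
\frac{\partial f}{\partial Y_{ik}} = -3\,Y_{ik}^{2}\, s_k \le 0 \quad \text{when } s_k > 0 .
$$

Since `np.random.rand` draws entries of $Z$ in $[0, 1)$, every $s_k \ge 0$, so $f \to -\infty$ as $x_j \to \infty$ or $Y_{ik} \to \infty$ (the latter whenever $s_k > 0$).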