I am trying to modify Ben Moseley's code, which is available on GitHub: https://github.com/benmoseley/FBPINNs. I want to insert into the loss function a vector of values that depends on the x, y coordinates. To do this, I need to interpolate the original vector Z as a function of x and y, and then extract its values at the same coordinates at which the algorithm samples x and y, so that the values match. The problem is that inside the loss function I cannot use libraries other than JAX, and as far as I know JAX has no functions for 2D interpolation.
I have tried to work around the problem in several ways without success. One idea was to extract the x, y points sampled by the algorithm, but the code is quite complex and I haven't managed to do it. Could anyone give me some advice or help with this?
There is the function jax.scipy.ndimage.map_coordinates, but it doesn't work properly for me: the values it returns are meaningless.
If linear or nearest-neighbor interpolation is sufficient, you may be able to do what you need with jax.scipy.interpolate.RegularGridInterpolator.
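A minimal sketch of how this could look, assuming Z is sampled on a regular grid described by two 1D coordinate arrays, x_grid and y_grid. All the data below is made up for illustration, and example_loss is a hypothetical stand-in for a loss term, not the actual FBPINNs loss API. The idea is to build the interpolator once from the grid and then evaluate it at whatever (x, y) points the loss receives, so the interpolated Z values line up with the sampled points. Since the interpolator is implemented with jax.numpy, it should be usable under jax.jit and jax.grad.

```python
import jax
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

# Hypothetical gridded data: replace x_grid, y_grid and Z with your own arrays.
# Z must have shape (len(x_grid), len(y_grid)), i.e. Z[i, j] is the value at
# (x_grid[i], y_grid[j]); indexing="ij" in meshgrid gives exactly that layout.
x_grid = jnp.linspace(0.0, 1.0, 51)
y_grid = jnp.linspace(0.0, 2.0, 101)
X, Y = jnp.meshgrid(x_grid, y_grid, indexing="ij")
Z = jnp.sin(jnp.pi * X) * jnp.cos(jnp.pi * Y)

# Build the interpolator once, outside the loss function.
interp_Z = RegularGridInterpolator((x_grid, y_grid), Z, method="linear")


def z_at(xy):
    """Bilinearly interpolate Z at xy, an array of shape (N, 2) with columns (x, y)."""
    return interp_Z(xy)


@jax.jit
def example_loss(u_pred, xy):
    # Hypothetical loss term: compare the network output at the sampled points
    # with Z interpolated at those same points.
    return jnp.mean((u_pred - z_at(xy)) ** 2)


# Three sample (x, y) points, as a sampler might produce them.
xy = jnp.array([[0.5, 0.0], [0.25, 1.0], [0.1, 0.3]])
print(z_at(xy))
```

Note that points outside the grid return NaN by default (fill_value=jnp.nan). As in SciPy, passing fill_value=None to the constructor makes it extrapolate linearly instead, which may be safer if the sampler can land slightly outside the domain covered by Z.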
If you need something more sophisticated, like spline interpolation, there is nothing included in JAX itself. That said, you may be able to find downstream implementations that work for you. One I came across that might be worth trying is in the jax_cosmo project: https://jax-cosmo.readthedocs.io/en/latest/_modules/jax_cosmo/scipy/interpolate.html.
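Regarding jax.scipy.ndimage.map_coordinates from the question: it expects fractional array indices, not physical x, y coordinates, and JAX only supports order=0 (nearest) and order=1 (linear). Passing physical coordinates directly, or building Z with meshgrid's default indexing="xy" (which swaps the axes to shape (ny, nx)), would both produce meaningless values; this is a guess at what went wrong, not a confirmed diagnosis. The sketch below assumes a uniformly spaced grid with hypothetical bounds x0, x1, y0, y1 and sizes nx, ny, and rescales physical coordinates to indices before calling it.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# Hypothetical uniform grid; this index conversion only holds for uniform spacing.
x0, x1, nx = 0.0, 1.0, 51
y0, y1, ny = 0.0, 2.0, 101
x_grid = jnp.linspace(x0, x1, nx)
y_grid = jnp.linspace(y0, y1, ny)
X, Y = jnp.meshgrid(x_grid, y_grid, indexing="ij")
Z = jnp.sin(jnp.pi * X) * jnp.cos(jnp.pi * Y)  # Z[i, j] is the value at (x_grid[i], y_grid[j])


def z_at(x, y):
    # Map physical coordinates to fractional indices:
    # x in [x0, x1] -> i in [0, nx - 1], y in [y0, y1] -> j in [0, ny - 1].
    i = (x - x0) / (x1 - x0) * (nx - 1)
    j = (y - y0) / (y1 - y0) * (ny - 1)
    # order=1 gives bilinear interpolation; mode="nearest" clamps indices that
    # fall outside the grid to the edge values instead of returning cval.
    return map_coordinates(Z, [i, j], order=1, mode="nearest")


x = jnp.array([0.5, 0.25, 0.1])
y = jnp.array([0.0, 1.0, 0.3])
print(z_at(x, y))
```

On a uniform grid this should agree with RegularGridInterpolator using method="linear"; RegularGridInterpolator has the advantage of also handling non-uniform grid spacing.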