Tags: python, jax, numpyro

JAX with JIT and custom differentiation


I am working with JAX through numpyro. Specifically, I want to use a B-spline function (e.g. as implemented in scipy.interpolate.BSpline) whose input points depend on some of the parameters in the model. Thus, I need to be able to differentiate the B-spline in JAX (only with respect to the input argument, and of course not with respect to the knots or the integer order).

I can easily make this work with jax.custom_vjp, but not when JIT is used, as it is in numpyro. I looked at the following:

  1. https://github.com/google/jax/issues/1142
  2. https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

and it seems like the best hope is to use a callback, though I cannot entirely figure out how that would work. The TensorFlow example with reverse-mode autodiff at

https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-jax-function-on-another-device-with-reverse-mode-autodiff-support

does not seem to use JIT.
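To see why a callback seems necessary, here is a minimal sketch of the failure mode (not the B-spline code itself): any host-side NumPy/SciPy call on a traced value fails under JIT, because the abstract tracer cannot be materialized as a concrete array.

import jax
import numpy as np

@jax.jit
def f(x):
    # Under jit, x is an abstract tracer; forcing it into a concrete NumPy
    # array raises jax.errors.TracerArrayConversionError.
    return np.sin(np.asarray(x))

# f(1.0)  # -> TracerArrayConversionError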

The example

Here is Python code that works without JIT (see the b_spline_basis() function):

from scipy.interpolate import BSpline
import numpy as np
from numpy import typing as npt
from functools import partial
import jax

doubleArray = npt.NDArray[np.double]

# see
#   https://stackoverflow.com/q/74699053/5861244
#   https://en.wikipedia.org/wiki/B-spline#Derivative_expressions
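#
# The derivative recursion used below:
#   d/dx B_{i,k}(x) = k * ( B_{i,k-1}(x) / (t[i+k] - t[i])
#                           - B_{i+1,k-1}(x) / (t[i+k+1] - t[i+1]) )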
def _b_spline_deriv_inner(spline: BSpline, deriv_basis: doubleArray) -> doubleArray:  # type: ignore[no-any-unimported]
    out = np.zeros((deriv_basis.shape[0], deriv_basis.shape[1] - 1))

    # Second term of the recursion. The `scale != 0` guard skips terms whose
    # knot span is empty (repeated knots), so the loops can cover every column;
    # for clamped knots the boundary columns are skipped by the guard anyway.
    for col_index in range(out.shape[1]):
        scale = spline.t[col_index + spline.k + 1] - spline.t[col_index + 1]
        if scale != 0:
            out[:, col_index] = -deriv_basis[:, col_index + 1] / scale

    # First term of the recursion.
    for col_index in range(out.shape[1]):
        scale = spline.t[col_index + spline.k] - spline.t[col_index]
        if scale != 0:
            out[:, col_index] += deriv_basis[:, col_index] / scale

    return float(spline.k) * out


# Evaluate the `deriv`-th derivative of every B-spline basis function at the
# points `x`, recursing through lower-order bases down to the design matrix.
def _b_spline_eval(spline: BSpline, x: doubleArray, deriv: int) -> doubleArray:  # type: ignore[no-any-unimported]
    if deriv == 0:
        return spline.design_matrix(x=x, t=spline.t, k=spline.k).todense()
    elif spline.k <= 0:
        return np.zeros((x.shape[0], spline.t.shape[0] - spline.k - 1))

    return _b_spline_deriv_inner(
        spline=spline,
        deriv_basis=_b_spline_eval(
            BSpline(t=spline.t, k=spline.k - 1, c=np.zeros(spline.c.shape[0] + 1)), x=x, deriv=deriv - 1
        ),
    )


@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    return _b_spline_eval(spline=BSpline(t=knots, k=order, c=np.zeros((order + knots.shape[0] - 1))), x=x, deriv=deriv)[
        :, 1:
    ]


# Forward pass: return the basis and, as the residual, the next-higher
# derivative with respect to x, which the backward pass needs for the chain rule.
def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    return (
        _b_spline_eval(spline=spline, x=x, deriv=deriv)[:, 1:],
        _b_spline_eval(spline=spline, x=x, deriv=deriv + 1)[:, 1:],
    )


# Backward pass: column j of the output has x-derivative partials[:, j], so the
# cotangent of x is the row-wise sum of partials * grad.
def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    return (jax.numpy.sum(partials * grad, axis=1),)


b_spline_basis.defvjp(b_spline_basis_fwd, b_spline_basis_bwd)

if __name__ == "__main__":
    # tests

    knots = np.array([0, 0, 0, 0, 0.25, 1, 1, 1, 1])
    x = np.array([0.1, 0.5, 0.9])
    order = 3

    def test_jax(basis: doubleArray, partials: doubleArray, deriv: int) -> None:
        weights = jax.numpy.arange(1, basis.shape[1] + 1)

        def test_func(x: doubleArray) -> doubleArray:
            return jax.numpy.sum(jax.numpy.dot(b_spline_basis(knots=knots, order=order, deriv=deriv, x=x), weights))  # type: ignore[no-any-return]

        assert np.allclose(test_func(x), np.sum(np.dot(basis, weights)))
        assert np.allclose(jax.grad(test_func)(x), np.dot(partials, weights))

    deriv0 = np.transpose(
        np.array(
            [
                0.684,
                0.166666666666667,
                0.00133333333333333,
                0.096,
                0.444444444444444,
                0.0355555555555555,
                0.004,
                0.351851851851852,
                0.312148148148148,
                0,
                0.037037037037037,
                0.650962962962963,
            ]
        ).reshape(-1, 3)
    )

    deriv1 = np.transpose(
        np.array(
            [
                2.52,
                -1,
                -0.04,
                1.68,
                -0.666666666666667,
                -0.666666666666667,
                0.12,
                1.22222222222222,
                -2.29777777777778,
                0,
                0.444444444444444,
                3.00444444444444,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv0, deriv1, deriv=0)

    deriv2 = np.transpose(
        np.array(
            [
                -69.6,
                4,
                0.8,
                9.6,
                -5.33333333333333,
                5.33333333333333,
                2.4,
                -2.22222222222222,
                -15.3777777777778,
                0,
                3.55555555555556,
                9.24444444444445,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv1, deriv2, deriv=1)

    deriv3 = np.transpose(
        np.array(
            [
                504,
                -8,
                -8,
                -144,
                26.6666666666667,
                26.6666666666667,
                24,
                -32.8888888888889,
                -32.8888888888889,
                0,
                14.2222222222222,
                14.2222222222222,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv2, deriv3, deriv=2)

Solution

  • The best way to accomplish this is probably to use a combination of jax.custom_jvp and jax.pure_callback.

    Unfortunately, pure_callback is relatively new and does not have great documentation yet, but you can find examples of its use in the JAX user forums (for example here).

    Copied here for posterity, this is an example of computing the sine and cosine via numpy callbacks in jit-compatible code with custom JVP rules for autodiff.

    import jax
    import numpy as np
    jax.config.update('jax_enable_x64', True)
    
    @jax.custom_jvp
    def np_sin(x):
      # Compute the sine by calling-back to np.sin on the host.
      return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(np.shape(x), np.float64), x)
    
    @np_sin.defjvp
    def _np_sin_jvp(primals, tangents):
      x, = primals
      dx, = tangents
      return np_sin(x), np_cos(x) * dx  # d sin(x) = cos(x) dx
    
    @jax.custom_jvp
    def np_cos(x):
      # Compute the cosine by calling-back to np.cos on the host.
      return jax.pure_callback(np.cos, jax.ShapeDtypeStruct(np.shape(x), np.float64), x)
    
    @np_cos.defjvp
    def _np_cos_jvp(primals, tangents):
      x, = primals
      dx, = tangents
      return np_cos(x), -np_sin(x) * dx  # d cos(x) = -sin(x) dx
    
    
    print(np_sin(1.0))
    # 0.8414709848078965
    print(np_cos(1.0))
    # 0.5403023058681398
    print(jax.jit(jax.grad(np_sin))(1.0))
    # 0.5403023058681398
    

    Note that since pure_callback operates by sending data back to the host, it will generally have a lot of overhead on accelerators like GPU and TPU, although in a single-CPU setting this kind of approach can perform well.
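    Applied to the question's B-spline basis, the same pattern might look roughly like the sketch below (untested; it reuses the _b_spline_eval() helper from the question, assumes jax_enable_x64 is set as in the example above, and the name b_spline_basis_cb is made up here). The scipy evaluation happens inside jax.pure_callback with the output shape declared up front, and the JVP rule reuses the same function with deriv + 1:

    from functools import partial

    @partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2))
    def b_spline_basis_cb(knots, order, deriv, x):
      # The basis (with its first column dropped, as in the question) has
      # knots.shape[0] - order - 2 columns; promise that shape to JAX up front.
      out_shape = jax.ShapeDtypeStruct((x.shape[0], knots.shape[0] - order - 2), x.dtype)

      def eval_on_host(x_host):
        spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
        basis = np.asarray(_b_spline_eval(spline=spline, x=np.asarray(x_host), deriv=deriv))
        return basis[:, 1:].astype(out_shape.dtype)

      return jax.pure_callback(eval_on_host, out_shape, x)

    @b_spline_basis_cb.defjvp
    def _b_spline_basis_cb_jvp(knots, order, deriv, primals, tangents):
      x, = primals
      dx, = tangents
      # The x-derivative of the deriv-th derivative basis is the (deriv + 1)-th
      # one, so the tangent is that matrix scaled row-wise by dx.
      return (
          b_spline_basis_cb(knots, order, deriv, x),
          b_spline_basis_cb(knots, order, deriv + 1, x) * dx[:, None],
      )

    With this, something like jax.jit(jax.grad(test_func))(x) from the tests above should work once test_func calls b_spline_basis_cb, since tracing only ever sees the callback, at the cost of a host round-trip per evaluation.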