python, deep-learning, jax, autograd, flax

Getting derivatives of a neural network with respect to its inputs, in batches, in JAX


There is a neural network that takes two variables as input: net(x, t), where x is usually d-dimensional and t is a scalar. The NN outputs a vector of length d. x and t may come in batches, so x has shape (b, d), t has shape (b, 1), and the output has shape (b, d). I need to find derivatives of the network's output with respect to its inputs x and t (details below).

Since the NN doesn't output a scalar, I don't think jax.grad would help here. I know how to do what I described in PyTorch, but I'm totally new to JAX. I'd really appreciate your help with this question!
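
For instance, on a toy vector-valued function (just to illustrate the issue, not my actual network), jax.grad refuses to run:

import jax
from jax import numpy as jnp

f = lambda x: jnp.sin(x)  # toy function: vector in, vector out
x0 = jnp.ones(3)
jax.grad(f)(x0)  # error: gradient is only defined for scalar-output functions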

Here is an example:

import jax
from jax import numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state



class NN(nn.Module):
    hid_dim : int # Number of hidden neurons
    output_dim : int # Number of output neurons

    @nn.compact  
    def __call__(self, x, t):
        out = jnp.hstack((x, t))
        out = nn.tanh(nn.Dense(features=self.hid_dim)(out))
        out = nn.tanh(nn.Dense(features=self.hid_dim)(out))
        out = nn.Dense(features=self.output_dim)(out)
        return out

d = 3
batch_size = 10
net = NN(hid_dim=100, output_dim=d)

rng_nn, rng_inp1, rng_inp2 = jax.random.split(jax.random.PRNGKey(100), 3)
inp_x = jax.random.normal(rng_inp1, (1, d)) # batch, d
inp_t = jax.random.normal(rng_inp2, (1, 1))
params_net = net.init(rng_nn, inp_x, inp_t)

x = jax.random.normal(rng_inp2, (batch_size, d)) # batch, d
t = jax.random.normal(rng_inp1, (batch_size, 1))

out_net = net.apply(params_net, x, t)

optimizer = optax.adam(1e-3)

model_state = train_state.TrainState.create(apply_fn=net.apply,
                                            params= params_net,
                                            tx=optimizer)

I'd like to calculate an $L_2$ loss based on some derivatives of the NN's outputs with respect to its inputs. For example, I'd like to have df/dx or df/dt, where f is the NN, and also the gradient of the divergence with respect to x. I assume it'd be something like

def find_derivatives(net, params, X, t):
    # JVP with an all-ones tangent in t: directional derivative of the output w.r.t. t
    d_dt = lambda net, params, X, t: jax.jvp(
        lambda time: net(params, X, time), (t,), (jnp.ones_like(t),))
    # JVP with an all-ones tangent in X: directional derivative of the output w.r.t. X
    d_dx = lambda net, params, X, t: jax.jvp(
        lambda X_: net(params, X_, t), (X,), (jnp.ones_like(X),))
    out_f, df_dt = d_dt(net.apply, params, X, t)

    # Applying d_dx twice gives a second directional derivative w.r.t. X
    d_ddx = lambda net, params, X, t: d_dx(
        lambda params, X, t: d_dx(net, params, X, t)[1], params, X, t)
    df_dx, df_ddx = d_ddx(net.apply, params, X, t)

    return out_f, df_dt, df_dx, df_ddx


out_f, df_dt, df_dx, df_ddx = find_derivatives(net, params_net, x, t)

Solution

  • I would avoid using jax.jvp here, because it is meant as a lower-level API. You can use jax.jacobian to compute the Jacobian (since your function has a vector-valued output), and jax.vmap for batching. For example:

    df_dx = jax.vmap(
        jax.jacobian(net.apply, argnums=1),
        in_axes=(None, 0, 0)
      )(params_net, x, t)
    print(df_dx.shape)  # (10, 3, 3)
    
    df_dt = jax.vmap(
        jax.jacobian(net.apply, argnums=2),
        in_axes=(None, 0, 0)
      )(params_net, x, t).reshape(10, 3)
    print(df_dt.shape)  # (10, 3)
    

    Here df_dx is the batch-wise Jacobian of the 3-dimensional output vector with respect to the 3-dimensional input x, and df_dt is the batch-wise derivative of the 3-dimensional output vector with respect to the scalar input t.
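
    The question also asks for the gradient of the divergence with respect to x. The divergence is the trace of the per-example Jacobian, and since it is a scalar, jax.grad applies directly on top of jax.jacobian. A minimal sketch reusing net, params_net, x, and t from above (the helper name divergence is just for illustration):

    def divergence(params, x_single, t_single):
        # x_single: (d,), t_single: (1,) -- one example, no batch axis
        jac = jax.jacobian(net.apply, argnums=1)(params, x_single, t_single)  # (d, d)
        return jnp.trace(jac)  # scalar: sum_i d f_i / d x_i

    div = jax.vmap(divergence, in_axes=(None, 0, 0))(params_net, x, t)
    print(div.shape)  # (10,)

    grad_div = jax.vmap(
        jax.grad(divergence, argnums=1),  # gradient of the scalar divergence w.r.t. x
        in_axes=(None, 0, 0)
      )(params_net, x, t)
    print(grad_div.shape)  # (10, 3)

    These quantities can then be combined into an $L_2$ loss and used inside a standard Flax/optax training step with the TrainState you already created.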