There is a neural network that takes two variables as input: net(x, t), where x is usually d-dimensional and t is a scalar. The NN outputs a vector of length d. x and t might be batched, so x has shape (b, d), t has shape (b, 1), and the output has shape (b, d). I need to find: (1) d out/dt of the NN output, which should be a d-dimensional vector (or (batch, d)); (2) d out/dx of the NN with respect to x, which should still be a (batch, d) vector.

Since the NN doesn't output a scalar, I don't think jax.grad would help here. I know how to do what I described in torch, but I'm totally new to JAX. I'd really appreciate your help with this question!
Here is an example:
import jax
from jax import numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax

class NN(nn.Module):
    hid_dim: int     # Number of hidden neurons
    output_dim: int  # Number of output neurons

    @nn.compact
    def __call__(self, x, t):
        out = jnp.hstack((x, t))  # concatenate x and t into a single input
        out = nn.tanh(nn.Dense(features=self.hid_dim)(out))
        out = nn.tanh(nn.Dense(features=self.hid_dim)(out))
        out = nn.Dense(features=self.output_dim)(out)
        return out

d = 3
batch_size = 10
net = NN(hid_dim=100, output_dim=d)

rng_nn, rng_inp1, rng_inp2 = jax.random.split(jax.random.PRNGKey(100), 3)
inp_x = jax.random.normal(rng_inp1, (1, d))  # (batch, d)
inp_t = jax.random.normal(rng_inp2, (1, 1))  # (batch, 1)
params_net = net.init(rng_nn, inp_x, inp_t)

x = jax.random.normal(rng_inp2, (batch_size, d))  # (batch, d)
t = jax.random.normal(rng_inp1, (batch_size, 1))  # (batch, 1)
out_net = net.apply(params_net, x, t)             # (batch, d)

optimizer = optax.adam(1e-3)
model_state = train_state.TrainState.create(apply_fn=net.apply,
                                            params=params_net,
                                            tx=optimizer)
I'd like to calculate an $L_2$ loss based on some derivatives of the NN's outputs with respect to its inputs. For example, I'd like to have d f/dx or d f/dt, where f is the NN, and also the gradient of the divergence with respect to x. I assume it'd be something like:
def find_derivatives(net, params, X, t):
    # Forward-mode derivatives: push a tangent of ones through t (resp. X)
    d_dt = lambda net, params, X, t: jax.jvp(lambda time: net(params, X, time), (t,), (jnp.ones_like(t),))
    d_dx = lambda net, params, X, t: jax.jvp(lambda X_: net(params, X_, t), (X,), (jnp.ones_like(X),))
    out_f, df_dt = d_dt(net.apply, params, X, t)
    # Second derivative w.r.t. X: differentiate the first derivative again
    d_ddx = lambda net, params, X, t: d_dx(lambda params, X, t: d_dx(net, params, X, t)[1], params, X, t)
    df_dx, df_ddx = d_ddx(net.apply, params, X, t)
    return out_f, df_dt, df_dx, df_ddx
out_f, df_dt, df_dx, df_ddx = find_derivatives(net, params_net, x, t)
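As a side note on the jvp-based sketch above: jax.jvp with a tangent of all ones computes a Jacobian-vector product, so for t (a single column per sample) it really does return d f/dt, whereas for x it returns the sum of the Jacobian's columns (a directional derivative), not the full Jacobian. A minimal sketch of the time-derivative case only, under the setup above:

# Sketch (not from the original post): pushing a tangent of ones through t
# returns both f(x, t) and df/dt in one forward-mode pass, because each
# sample has a single t column.
f_of_t = lambda time: net.apply(params_net, x, time)
out_f, df_dt = jax.jvp(f_of_t, (t,), (jnp.ones_like(t),))
print(out_f.shape, df_dt.shape)  # (10, 3) (10, 3)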
I would avoid using jax.jvp here, because it is meant as a lower-level API. You can use jax.jacobian to compute the Jacobian (since your function has multiple outputs), and vmap for batching. For example:
df_dx = jax.vmap(
    jax.jacobian(net.apply, argnums=1),  # Jacobian w.r.t. the x argument
    in_axes=(None, 0, 0)                 # map over the batch axis of x and t
)(params_net, x, t)
print(df_dx.shape)  # (10, 3, 3)

df_dt = jax.vmap(
    jax.jacobian(net.apply, argnums=2),  # Jacobian w.r.t. the t argument
    in_axes=(None, 0, 0)
)(params_net, x, t).reshape(10, 3)       # drop the trailing singleton t axis
print(df_dt.shape)  # (10, 3)
Here df_dx is the batch-wise Jacobian of the 3-dimensional output vector with respect to the 3-dimensional input vector x, and df_dt is the batch-wise gradient of the 3-dimensional output vector with respect to the input t.
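The question also asks for the gradient of the divergence with respect to x, which the example above does not cover. One way to get it in the same style is to take the trace of the per-sample Jacobian and differentiate that scalar with jax.grad; this is a sketch under the setup above, and the helper name divergence is illustrative:

# Illustrative sketch: the divergence of f w.r.t. x is the trace of the
# per-sample Jacobian; jax.grad of that scalar gives its gradient w.r.t. x.
def divergence(params, x_single, t_single):
    jac = jax.jacobian(net.apply, argnums=1)(params, x_single, t_single)  # (d, d)
    return jnp.trace(jac)

div = jax.vmap(divergence, in_axes=(None, 0, 0))(params_net, x, t)        # (10,)
grad_div = jax.vmap(jax.grad(divergence, argnums=1),
                    in_axes=(None, 0, 0))(params_net, x, t)               # (10, 3)

Any of these quantities can then be combined into a scalar $L_2$ loss and differentiated with respect to params_net (for example with jax.grad over the loss) for training with the TrainState defined earlier.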