I have a function that simulates a stochastic differential equation. Currently, without the stochastic noise term, my code for simulating the process up to time t looks like this (and yes, I need to use JAX):
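```python
import jax

def evolve(u, t):
    # Drift-only Euler step for now; with noise this will become
    # u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise
    return u + dt * b(t, u)

def simulate(x, t):
    k = jax.numpy.floor(t / dt).astype(int)
    return jax.lax.fori_loop(0, k, lambda i, u: evolve(u, i * dt), x)
```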
Now, the pain comes with the noise. I'm a C++ guy who only occasionally needs to use Python for research/scientific work, and I really don't understand how I need to (or should) implement PRNG splitting here. I guess I would change evolve to:
```python
def evolve(u, t, key):
    noise = jax.random.multivariate_normal(key, jax.numpy.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise
```
But I guess that will not work properly. If I got it right, I need to use jax.random.split to split the key, because if I don't, I end up with correlated samples. But how and where do I need to split?
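For what it's worth, reusing a key is actually worse than "correlated": JAX's random functions are pure functions of the key, so the same key always yields the exact same draw. A minimal check, where the zero mean and identity covariance are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mean = jnp.zeros(2)
cov = jnp.eye(2)  # illustrative covariance (assumption)

# Reusing the same key: both calls return exactly the same draw.
same_1 = jax.random.multivariate_normal(key, mean, cov)
same_2 = jax.random.multivariate_normal(key, mean, cov)

# Splitting first: each subkey gives an independent draw.
k1, k2 = jax.random.split(key)
indep_1 = jax.random.multivariate_normal(k1, mean, cov)
indep_2 = jax.random.multivariate_normal(k2, mean, cov)

print(jnp.array_equal(same_1, same_2))    # True
print(jnp.array_equal(indep_1, indep_2))  # False
```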
Also: I guess I would need to modify simulate to def simulate(x, t, key). But then, should simulate also return the modified key?
And to make it even more complicated: I actually wrap simulate in a batch_simulate function that uses jax.vmap to process a whole batch of x's and t's. How do I pass the PRNG key to batch_simulate, how do I pass (and broadcast) it to jax.vmap, and what should batch_simulate return? At first glance, it seems to me that it would take a single PRNG key and split it into many (because of the vmap). But then what does the caller of batch_simulate do?
Completely lost on this. Any help is highly appreciated!
If I understand your setup correctly, you should make both evolve and simulate accept a key, and within simulate, use jax.random.fold_in to derive a unique key for each loop iteration:
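```python
import jax

def evolve(u, t, key):
    ...  # as in the question: draw the noise with `key` and take one step

def simulate(x, t, key):
    k = jax.numpy.floor(t / dt).astype(int)
    # fold_in(key, i) derives a distinct key for step i from the single input key
    return jax.lax.fori_loop(
        0, k, lambda i, u: evolve(u, i * dt, jax.random.fold_in(key, i)), x)
```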
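To make the fold_in approach concrete, here is a self-contained sketch that runs as-is. It relies on assumptions that are not in the question: a hypothetical Ornstein-Uhlenbeck drift b(t, u) = -u, a constant sigma of 1, dt = 0.01, and independent per-component noise from jax.random.normal instead of multivariate_normal. Every step gets its own key via fold_in(key, i), so the same key reproduces the same path and a different key gives a different one.

```python
import jax
import jax.numpy as jnp

dt = 0.01              # step size (assumed)
sqrt_dt = jnp.sqrt(dt)

def b(t, u):           # hypothetical drift: Ornstein-Uhlenbeck pull towards 0
    return -u

def sigma(t, u):       # hypothetical constant diffusion coefficient
    return 1.0

def evolve(u, t, key):
    # Fresh standard normal noise for this step, drawn from this step's own key.
    noise = jax.random.normal(key, shape=jnp.shape(u))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

def simulate(x, t, key):
    k = jnp.floor(t / dt).astype(int)
    # fold_in(key, i) gives a distinct, reproducible key for every step i,
    # so no key has to be carried through the loop state.
    return jax.lax.fori_loop(
        0, k, lambda i, u: evolve(u, i * dt, jax.random.fold_in(key, i)), x)

key = jax.random.PRNGKey(0)
x0 = jnp.zeros(3)
path_a = simulate(x0, 1.0, key)
path_b = simulate(x0, 1.0, key)                    # same key: identical result
path_c = simulate(x0, 1.0, jax.random.PRNGKey(1))  # other key: different result
```

A design note: because fold_in is a pure function of (key, i), the fori_loop carry stays just u, which is also why simulate does not have to return a key.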
Then, if you want to vmap over simulate, you can split the key into one key per batch element and map over those keys:
```python
x_batch = ...  # your batched x inputs
t_batch = ...  # your batched t inputs
key_batch = jax.random.split(key, x_batch.shape[0])
batch_result = jax.vmap(simulate)(x_batch, t_batch, key_batch)
```
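Putting the pieces together, here is a hedged sketch of a batch_simulate wrapper and of what its caller might do. The simulate below is a hypothetical stand-in (plain Brownian motion, drift 0 and diffusion 1, dt = 0.01), and the batch_simulate signature is an assumption. The pattern shown is the usual functional one: the caller keeps one key, splits off a fresh subkey for each call, and batch_simulate only consumes its key, so nothing needs to be returned.

```python
import jax
import jax.numpy as jnp

dt = 0.01  # step size (assumed)

def simulate(x, t, key):
    # Hypothetical stand-in for the question's simulate: Brownian motion,
    # using fold_in to get one key per step.
    k = jnp.floor(t / dt).astype(int)

    def step(i, u):
        noise = jax.random.normal(jax.random.fold_in(key, i), shape=u.shape)
        return u + jnp.sqrt(dt) * noise

    return jax.lax.fori_loop(0, k, step, x)

def batch_simulate(x_batch, t_batch, key):
    # Split the incoming key into one independent key per batch element;
    # vmap then maps over axis 0 of x_batch, t_batch and key_batch alike.
    key_batch = jax.random.split(key, x_batch.shape[0])
    return jax.vmap(simulate)(x_batch, t_batch, key_batch)

# Caller side: keep one "master" key and split off a fresh subkey per call.
key = jax.random.PRNGKey(42)
x_batch = jnp.zeros((4, 3))                # 4 initial states in 3 dimensions
t_batch = jnp.array([0.5, 1.0, 1.5, 2.0])  # one end time per initial state

key, subkey = jax.random.split(key)
first = batch_simulate(x_batch, t_batch, subkey)
key, subkey = jax.random.split(key)
second = batch_simulate(x_batch, t_batch, subkey)  # fresh, independent noise
```

Since each t can differ, the loop bound is traced inside vmap, so fori_loop falls back to a while_loop that runs until every batch element has finished; this works for forward simulation but is not reverse-mode differentiable.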