I have a function which simulates a stochastic differential equation. Currently, without stochastic noise, my invocation for simulating the process up to time `t` looks like this (and, yeah, I need to use JAX):
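```python
import jax

def evolve(u, t):
    # Drift-only Euler step; b and dt are defined elsewhere.
    # With noise this would become:
    # return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise
    return u + dt * b(t, u)

def simulate(x, t):
    k = jax.numpy.floor(t / dt).astype(int)
    u = jax.lax.fori_loop(0, k, lambda i, u: evolve(u, i * dt), x)
    return u
```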
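For reference, a self-contained version of the noise-free loop is shown below. The drift `b(t, u) = -u` and `dt = 0.01` are hypothetical placeholders, not the actual model; the point is only the `fori_loop` structure, where the loop carry starts from `x` and the final carry is the state at time `t`.

```python
import jax
import jax.numpy as jnp

# Hypothetical values for illustration only.
dt = 0.01

def b(t, u):
    # Hypothetical drift: exponential decay du/dt = -u.
    return -u

def evolve(u, t):
    # One explicit Euler step of the drift (no noise).
    return u + dt * b(t, u)

def simulate(x, t):
    # Number of whole steps that fit into [0, t].
    k = jnp.floor(t / dt).astype(int)
    # The carry starts at x; each iteration advances it by one step.
    return jax.lax.fori_loop(0, k, lambda i, u: evolve(u, i * dt), x)

u = simulate(jnp.ones(3), 1.0)  # 100 steps, each entry is 0.99**100 (about 0.366)
```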
Now, the pain comes with the noise. I'm a C++ guy who only occasionally needs to use Python for research/scientific work, and I really don't understand how I need (or should) implement PRNG splitting here. I guess I would change `evolve` to:
```python
def evolve(u, t, key):
    noise = jax.random.multivariate_normal(key, jax.numpy.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise
```
But that will not work properly, I guess. If I got it right, I need to use `jax.random.split` to split the key, because if I don't, every step draws the same (i.e. perfectly correlated) noise. But how and where do I need to split?
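A quick check of why splitting matters, as a minimal sketch (`jax.random.normal` is used instead of the multivariate draw purely for brevity): reusing one key reproduces exactly the same numbers, while two keys obtained from `jax.random.split` give different ones.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key: both draws are bit-for-bit identical.
first = jax.random.normal(key, (3,))
second = jax.random.normal(key, (3,))
print(jnp.array_equal(first, second))  # True

# Splitting: two fresh keys give different draws.
key1, key2 = jax.random.split(key)
c = jax.random.normal(key1, (3,))
d = jax.random.normal(key2, (3,))
print(jnp.array_equal(c, d))  # False
```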
Also: I guess I would need to change `simulate` to `def simulate(x, t, key)`. But then, should `simulate` also return the modified key?
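One possible way to have `simulate` return an updated key, sketched under assumptions: thread the key through the loop carry, split off a subkey at every step, and return the leftover key together with the state. The drift `b`, diffusion `sigma`, dimension `d`, covariance `cov` and `dt` below are hypothetical placeholders, and the noise is drawn as a single vector of shape `(d,)` rather than `(n, d)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model; stands in for the real b, sigma, d, covariance.
dt = 0.01
sqrt_dt = jnp.sqrt(dt)
d = 2
cov = jnp.eye(d)

def b(t, u):
    return -u  # hypothetical linear drift

def sigma(t, u):
    return 0.5  # hypothetical constant diffusion

def evolve(u, t, key):
    # One Euler-Maruyama step that consumes `key` exactly once.
    noise = jax.random.multivariate_normal(key, jnp.zeros(d), cov)
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

def simulate(x, t, key):
    k = jnp.floor(t / dt).astype(int)

    def body(i, carry):
        u, loop_key = carry
        # Split off a fresh subkey for this step; keep the other half.
        loop_key, subkey = jax.random.split(loop_key)
        return evolve(u, i * dt, subkey), loop_key

    u, key = jax.lax.fori_loop(0, k, body, (x, key))
    # The returned key has not been used for any draw, so the caller can reuse it.
    return u, key

u, key = simulate(jnp.ones(d), 1.0, jax.random.PRNGKey(0))
```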
And to make it even more complicated: I actually wrap `simulate` into a `batch_simulate` function which uses `jax.vmap` to process a whole batch of `x`'s and `t`'s. How do I pass the PRNG key to that `batch_simulate` function, how do I pass it (and broadcast it) to `jax.vmap`, and what should `batch_simulate` return? At first glance, it seems to me that it would take a single PRNG key and split it into many (due to the `vmap`). But what does the caller of `batch_simulate` do then ...
Completely lost on this. Any help is highly appreciated!
If I understand your setup correctly, you should make both `evolve` and `simulate` accept a key, and within `simulate`, use `jax.random.fold_in` to derive a unique key for each loop iteration:
```python
def evolve(u, t, key):
    ...  # as in the question, drawing the noise from `key`

def simulate(x, t, key):
    k = jax.numpy.floor(t / dt).astype(int)
    u = jax.lax.fori_loop(0, k, lambda i, u: evolve(u, i * dt, jax.random.fold_in(key, i)), x)
    return u
```
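A runnable version of the `fold_in` approach, assuming a hypothetical toy model (`b`, `sigma`, `d`, `cov` and `dt` are placeholders). Because each step's key is derived deterministically from the single input key and the step index, `simulate` does not need to return a key; the caller should simply treat the key it passed in as used up.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model; stands in for the real b, sigma, d, covariance.
dt = 0.01
sqrt_dt = jnp.sqrt(dt)
d = 2
cov = jnp.eye(d)

def b(t, u):
    return -u  # hypothetical linear drift

def sigma(t, u):
    return 0.5  # hypothetical constant diffusion

def evolve(u, t, key):
    # One Euler-Maruyama step; `key` is used exactly once.
    noise = jax.random.multivariate_normal(key, jnp.zeros(d), cov)
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

def simulate(x, t, key):
    k = jnp.floor(t / dt).astype(int)
    # fold_in(key, i) gives a distinct key for step i without splitting.
    step = lambda i, u: evolve(u, i * dt, jax.random.fold_in(key, i))
    return jax.lax.fori_loop(0, k, step, x)

x0 = jnp.ones(d)
path_a = simulate(x0, 1.0, jax.random.PRNGKey(0))
path_b = simulate(x0, 1.0, jax.random.PRNGKey(0))  # same key, same result
path_c = simulate(x0, 1.0, jax.random.PRNGKey(1))  # different key, different result
```

The same key reproduces the same path and a different key gives a different one, which is what makes runs reproducible.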
Then, if you want to `vmap` over `simulate`, you can split the key and map over the resulting batch of keys:
```python
x_batch = ...  # your batched x inputs
t_batch = ...  # your batched t inputs
key_batch = jax.random.split(key, x_batch.shape[0])  # one key per batch element
batch_result = jax.vmap(simulate)(x_batch, t_batch, key_batch)
```
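A sketch of a `batch_simulate` wrapper and its caller, under the same hypothetical toy model. The wrapper takes a single key, splits it into one key per batch element, and returns only the simulated states. `jax.vmap` maps over the leading axis of every argument by default, so the batch of keys is sliced just like `x_batch` and `t_batch`; an argument that should be shared across the batch instead would be marked with `in_axes=None`. The caller keeps its own key and splits off a fresh subkey for every call.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model; stands in for the real b, sigma, d, covariance.
dt = 0.01
sqrt_dt = jnp.sqrt(dt)
d = 2
cov = jnp.eye(d)

def b(t, u):
    return -u  # hypothetical linear drift

def sigma(t, u):
    return 0.5  # hypothetical constant diffusion

def evolve(u, t, key):
    noise = jax.random.multivariate_normal(key, jnp.zeros(d), cov)
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

def simulate(x, t, key):
    k = jnp.floor(t / dt).astype(int)
    step = lambda i, u: evolve(u, i * dt, jax.random.fold_in(key, i))
    return jax.lax.fori_loop(0, k, step, x)

def batch_simulate(x_batch, t_batch, key):
    # One independent key per batch element, mapped along axis 0 like x and t.
    keys = jax.random.split(key, x_batch.shape[0])
    return jax.vmap(simulate)(x_batch, t_batch, keys)

# Caller: keep a main key and split off a fresh subkey for every call.
key = jax.random.PRNGKey(42)
x_batch = jnp.ones((4, d))
t_batch = jnp.array([0.5, 1.0, 1.5, 2.0])

key, subkey = jax.random.split(key)
first = batch_simulate(x_batch, t_batch, subkey)
key, subkey = jax.random.split(key)
second = batch_simulate(x_batch, t_batch, subkey)
```

Different `t` values per element are fine here: under `vmap` the loop bound becomes a traced value, so `fori_loop` runs as a `while_loop` that keeps iterating until the longest element is done while the finished elements stay fixed.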