Tags: python, random, python-3.8, jax

How to handle PRNG splitting in a jax.vmap context?


I have a function which simulates a stochastic differential equation. Currently, without stochastic noise, my code for simulating the process up to time t looks like this (and, yeah, I need to use jax):

def evolve(u, t):
    # noise term sigma(t, u) * sqrt_dt * noise is still missing here
    return u + dt * b(t, u)

def simulate(x, t):
    k = jax.numpy.floor(t / dt).astype(int)
    u = jax.lax.fori_loop(0, k, lambda i, u: evolve(u, i * dt), x)
    return u

Now, the pain comes with the noise. I'm a C++ guy who only occasionally needs to use Python for research/scientific work, and I really don't understand how I need to (or should) implement PRNG splitting here. I guess I would change evolve to

def evolve(u, t, key):
    noise = jax.random.multivariate_normal(key, jax.numpy.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

But I guess that will not work properly. If I understand it right, I need to use jax.random.split to split the key, because if I don't, I end up with correlated samples. But how and where do I need to split?
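The effect of reusing versus splitting a key can be checked directly. In this minimal sketch, drawing twice with the same key gives bit-for-bit identical samples (not merely correlated ones), while drawing with two subkeys produced by jax.random.split gives different samples. jax.random.normal stands in for the multivariate draw; the key value 0 is arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key: both draws are identical
s1 = jax.random.normal(key, (3,))
s2 = jax.random.normal(key, (3,))

# Splitting first: each subkey yields its own, independent draw
k1, k2 = jax.random.split(key)
s3 = jax.random.normal(k1, (3,))
s4 = jax.random.normal(k2, (3,))

print(jnp.array_equal(s1, s2))  # True
print(jnp.array_equal(s3, s4))  # False
```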

Also: I guess I would need to modify simulate to def simulate(x, t, key). But then, should simulate also return the modified key?

And to make it even more complicated: I actually wrap simulate into a batch_simulate function which uses jax.vmap to process a whole batch of x's and t's. How do I pass the PRNG key to that batch_simulate function, how do I pass it (and broadcast it) to jax.vmap, and what should batch_simulate return? At first glance, it seems to me that it would take a single PRNG key and split it into many (due to the vmap). But what does the caller of batch_simulate do then?

Completely lost on this. Any help is highly appreciated!


Solution

  • If I understand your setup correctly, you should make both evolve and simulate accept a key, and within simulate, use jax.random.fold_in to derive a distinct key for each loop iteration from the single key passed in:

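    def evolve(u, t, key):
        ...  # draw the noise with `key` and return the updated state
    
    def simulate(x, t, key):
        k = jax.numpy.floor(t / dt).astype(int)
        # fold the step index i into the key so every step gets independent noise
        u = jax.lax.fori_loop(
            0, k, lambda i, u: evolve(u, i * dt, jax.random.fold_in(key, i)), x)
        return u
    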

    Then, if you want to vmap over simulate, you can split the key into one key per batch element and map over those keys:

    key = jax.random.PRNGKey(0)  # or a key you already have
    x_batch = ...  # your batched x inputs
    t_batch = ...  # your batched t inputs
    key_batch = jax.random.split(key, x_batch.shape[0])  # one key per batch element
    
    batch_result = jax.vmap(simulate)(x_batch, t_batch, key_batch)
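Putting the pieces together, here is a minimal runnable sketch of the fold_in-plus-vmap approach, including what the caller of batch_simulate might do with its key. Everything about the model is a hypothetical stand-in, not taken from the question: n = 4 particles in d = 2 dimensions, dt = 0.01, an identity covariance_matrix, an Ornstein-Uhlenbeck drift b(t, u) = -u, and a constant diffusion sigma = 0.5. The batch_simulate signature is also an assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup (assumption): n particles in d dimensions
n, d = 4, 2
dt = 0.01
sqrt_dt = jnp.sqrt(dt)
covariance_matrix = jnp.eye(d)

def b(t, u):
    return -u  # hypothetical Ornstein-Uhlenbeck drift

def sigma(t, u):
    return 0.5  # hypothetical constant diffusion coefficient

def evolve(u, t, key):
    # One Euler-Maruyama step; `key` must be fresh for every step
    noise = jax.random.multivariate_normal(
        key, jnp.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

def simulate(x, t, key):
    k = jnp.floor(t / dt).astype(int)

    def body(i, u):
        # fold_in derives a distinct key for step i from the single input key
        return evolve(u, i * dt, jax.random.fold_in(key, i))

    return jax.lax.fori_loop(0, k, body, x)

def batch_simulate(x_batch, t_batch, key):
    # One key per trajectory, so each batch element gets its own noise
    keys = jax.random.split(key, x_batch.shape[0])
    return jax.vmap(simulate)(x_batch, t_batch, keys)

# Caller: split off a subkey for this call, keep `key` for later calls
key = jax.random.PRNGKey(42)
x_batch = jnp.ones((3, n, d))
t_batch = jnp.array([0.1, 0.2, 0.3])

key, subkey = jax.random.split(key)
result = batch_simulate(x_batch, t_batch, subkey)
print(result.shape)  # (3, 4, 2)
```

In this pattern neither simulate nor batch_simulate needs to return a key: each function only derives new keys from the one it receives, and the caller never reuses a key it has already handed out. Before the next call, the caller would split `key` again to get a fresh subkey.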