
Efficient stochastic numerical integration over many trajectories


I am implementing a numerical method for solving stochastic differential equations using the Euler-Maruyama method.

What I have works, but it is not efficient. Because the problem is stochastic, I need many trajectories, and right now I solve them one by one. Since the trajectories are independent, I feel I should be able to parallelize them.

The working code looks like this:

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import time
from numba import jit, njit
import os

def A(u):
    x=u[0]
    y=u[1]
    z=u[2]
    
    omega=1/2*np.sqrt((1+8*kappa*z*z))  
    
    
    A=np.array([[-2,omega,0],
                [-omega,0,0],
                [0,0,-kappa]])

    
    du=A.dot(u)
    return du


def B(u,w):
    x=u[0]
    y=u[1]
    z=u[2]
    
    g=np.sqrt(kappa*nth)

    B=np.array([[0],
                [1],
                [1]])*g
    
    return np.reshape(B*w,len(u0))


def SDE(A,B):
    u = np.zeros((len(u0),Nmax+1,Mmax),dtype=np.complex64)
    for m in range(Mmax):
        u[:,0,m]=u0
        for n in range(0,Nmax):
            u[:,n+1,m] = u[:,n,m]+dt*A(u[:,n,m])+B(u[:,n,m],w[n,m])*np.sqrt(dt)   
            
    return u


#Parameters
kappa=0.05
nth=1.
gamma=1

Mmax=100 #number of trajectories

Tmax=10. ##max value for time
dt=0.05
Nmax=int(Tmax/dt) ##number of steps

t_list=np.arange(0,Tmax+dt/2,dt)


w = np.random.randn(Nmax+1,Mmax)

u0 = np.array([1., 0., np.sqrt(nth)/2])

u_t=SDE(A,B)

u_mean=np.mean(u_t,axis=2)

This code is a simplification of my real code, where the system has a much larger dimension and there are many more trajectories.
This is not efficient: as I increase Mmax, the Python loop over trajectories grows with it.

Ideally, I would like my solver to look something like

def SDE(A,B):
    u = np.zeros((len(u0),Nmax+1,Mmax),dtype=np.complex64)
    u[:,0,:]=u0
    for n in range(0,Nmax):
        u[:,n+1,:] = u[:,n,:]+dt*A(u[:,n,:])+B(u[:,n,:],w[n,:])*np.sqrt(dt)
            
    return u

i.e., I would like to drop the loop over m and advance all trajectories at once. However, doing this naively does not work.

Another option would be to use Numba. However, after many attempts, I have not managed to compile my SDE solver with njit.
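The naive version fails because A builds a 3×3 matrix with np.array from omega, which becomes a vector of length Mmax once u holds all trajectories, and B reshapes its result to len(u0). A minimal NumPy sketch, assuming the same 3-variable model as above, writes the drift and noise component-wise so that broadcasting handles every trajectory in one step. The names drift, diffusion and sde_vectorized are hypothetical, chosen for this illustration; only the time loop remains in Python.

```python
import numpy as np

# Parameters copied from the question
kappa = 0.05
nth = 1.0
Tmax = 10.0
dt = 0.05
Nmax = int(Tmax / dt)   # number of time steps
Mmax = 100              # number of trajectories

def drift(u):
    """Drift for all trajectories at once; u has shape (3, Mmax)."""
    x, y, z = u                                   # each row has shape (Mmax,)
    omega = 0.5 * np.sqrt(1 + 8 * kappa * z * z)  # one omega per trajectory
    return np.stack([-2 * x + omega * y,
                     -omega * x,
                     -kappa * z])                 # shape (3, Mmax)

def diffusion(w_n):
    """Noise term; w_n has shape (Mmax,), one Gaussian sample per trajectory."""
    g = np.sqrt(kappa * nth)
    b = g * np.array([0.0, 1.0, 1.0])             # shape (3,)
    return b[:, None] * w_n[None, :]              # outer product, shape (3, Mmax)

def sde_vectorized(u0, w):
    """Euler-Maruyama: Python loop over time only, NumPy broadcasting over
    trajectories; w has shape (Nmax + 1, Mmax) as in the question."""
    u = np.zeros((len(u0), Nmax + 1, w.shape[1]), dtype=np.complex128)
    u[:, 0, :] = u0[:, None]                      # same initial state everywhere
    sqrt_dt = np.sqrt(dt)
    for n in range(Nmax):
        un = u[:, n, :]
        u[:, n + 1, :] = un + dt * drift(un) + diffusion(w[n]) * sqrt_dt
    return u
```

It is used like the original solver, e.g. with w = np.random.randn(Nmax + 1, Mmax), followed by u_mean = sde_vectorized(u0, w).mean(axis=2). The cost per time step is a handful of array operations on length-Mmax vectors instead of Mmax small Python calls.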


Solution

  • Here are the performance issues I found in the provided code:

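    Besides, np.reshape(B*w,len(u0)) is confusing: B has shape (3, 1), so B*w is a (3, 1) column, and the reshape only flattens it into a vector of length len(u0) (which must therefore also be 3). Building B as a 1D array directly would be clearer. Note that np.complex64 is single precision; the code below uses np.complex128 (double precision) instead.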
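A quick shape check shows what the reshape is doing. The noise value 0.3 below is arbitrary, standing in for one sample w[n, m]. Without the reshape, adding the (3, 1) column to a (3,) state vector would broadcast to a (3, 3) array.

```python
import numpy as np

kappa, nth = 0.05, 1.0
g = np.sqrt(kappa * nth)
w_nm = 0.3  # arbitrary stand-in for one noise sample w[n, m]

B_col = np.array([[0], [1], [1]]) * g      # column vector, shape (3, 1)
noise_col = B_col * w_nm                   # still shape (3, 1)
noise_vec = np.reshape(noise_col, 3)       # shape (3,), matches u[:, n, m]

# Without the reshape, (3,) + (3, 1) broadcasts to (3, 3)
broadcast_shape = (np.zeros(3) + noise_col).shape

# Building B as a 1D array gives the (3,) shape directly
B_vec = np.array([0.0, 1.0, 1.0]) * g
noise_direct = B_vec * w_nm
```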

    Here is the resulting Numba code:

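    import numpy as np
    import numba as nb  # new
    
    # kappa and nth must already be defined at this point: the explicit
    # signatures make Numba compile A and B immediately, and the global
    # values are frozen as compile-time constants.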
    
    @nb.njit('(complex128[:], complex128[::1])')
    def A(u, res):
        x=u[0]
        y=u[1]
        z=u[2]
    
        omega = 0.5 * np.sqrt((1+8*kappa*z*z))  
    
        res[0] = -2 * x + omega * y
        res[1] = -omega * x
        res[2] = -kappa * z
        return res
    
    @nb.njit('(complex128[:], float64, complex128[::1])')
    def B(u, w, res):
        g = np.sqrt(kappa*nth)
        res[0] = 0
        res[1] = g * w
        res[2] = g * w
        return res
    
    # No signature is provided so the first call will be much slower
    # But providing a signature here is complicated since A and B are functions
    @nb.njit
    def SDE(A,B):
        u = np.zeros((len(u0),Nmax+1,Mmax), dtype=np.complex128)
        sqrt_dt = np.sqrt(dt)
        for m in range(Mmax):
            u[:,0,m] = u0
            tmp1 = np.empty(3, dtype=np.complex128)
            tmp2 = np.empty(3, dtype=np.complex128)
            for n in range(0,Nmax):
                A(u[:,n,m],tmp1)
                B(u[:,n,m],w[n,m],tmp2)
                for i in range(3):
                    u[i,n+1,m] = u[i,n,m] + dt * tmp1[i] + tmp2[i] * sqrt_dt
        return u
    

    This is about 500 times faster than the original code on my machine (with an i5-9600KF CPU).
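When reproducing a comparison like this, the first call to the lazily compiled SDE includes JIT compilation time, as the code comment above warns. A minimal timing helper, with the hypothetical name benchmark, discards one warm-up call and keeps the best of several timed runs:

```python
import time

def benchmark(fn, *args, repeat=5):
    """Return the best wall-clock time (seconds) of fn(*args) over `repeat`
    runs, after one untimed warm-up call. For a lazily compiled Numba
    function, the warm-up call absorbs the compilation cost."""
    fn(*args)  # warm-up: compile if needed, result discarded
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best
```

Calling it once per version, e.g. benchmark(SDE, A, B), and dividing the two times gives the speedup. Taking the minimum rather than the mean reduces noise from other processes.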

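    I do not think multiple threads are needed once the code is optimized, since the computation is already quite fast. If it is not fast enough, you can add the flag parallel=True to the decorator and replace for m in range(Mmax) with for m in nb.prange(Mmax). However, this will not scale well because of false sharing caused by the memory layout: with u of shape (len(u0), Nmax+1, Mmax), the values of neighbouring trajectories m and m+1 are adjacent in memory, so threads working on different trajectories keep writing to the same cache lines. Swapping the first and last axes of u (so that the trajectory index m comes first) fixes this issue.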
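The layout difference can be checked with NumPy strides, i.e. the number of bytes between consecutive elements along each axis. Using the sizes from the question and complex128 elements (16 bytes each), neighbouring trajectories are only 16 bytes apart in the original layout. Cache lines are typically 64 bytes on x86 CPUs, so several trajectories share each one. In the swapped layout, each trajectory is a contiguous block.

```python
import numpy as np

Mmax, Nmax = 100, 200

u_old = np.zeros((3, Nmax + 1, Mmax), dtype=np.complex128)  # layout from the question
u_new = np.zeros((Mmax, Nmax + 1, 3), dtype=np.complex128)  # trajectory axis first

# Bytes between the same element of trajectories m and m + 1
print(u_old.strides[2])         # 16: neighbouring trajectories share cache lines
print(u_new.strides[0])         # 9648 = (Nmax + 1) * 3 * 16

# Bytes between consecutive components of one state vector
print(u_old[:, 0, 0].strides)   # (321600,) = (Nmax + 1) * Mmax * 16
print(u_new[0, 0, :].strides)   # (16,): the state vector is contiguous
```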

    Once parallelized and with a better memory layout, the final code should look like this:

    # A and B are the same as before
    
    # No signature is provided so the first call will be much slower
    # But providing a signature here is complicated since A and B are functions
    @nb.njit(parallel=True)
    def SDE(A,B):
        u = np.zeros((Mmax,Nmax+1,len(u0)), dtype=np.complex128)
        sqrt_dt = np.sqrt(dt)
        for m in nb.prange(Mmax):
            u[m,0,:] = u0
            tmp1 = np.empty(3, dtype=np.complex128)
            tmp2 = np.empty(3, dtype=np.complex128)
            for n in range(0,Nmax):
                A(u[m,n,:],tmp1)
                B(u[m,n,:],w[m,n],tmp2)
                for i in range(3):
                    u[m,n+1,i] = u[m,n,i] + dt * tmp1[i] + tmp2[i] * sqrt_dt
        return u
    
    
    # [...] same code
    
    w = np.random.randn(Nmax+1,Mmax).T.copy()
    u0 = np.array([1., 0., np.sqrt(nth)/2])
    u_t=SDE(A,B)
    u_mean=np.mean(u_t,axis=0)
    

    This version is about 2000 times faster than the original code.

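    Note that w is a global variable, so Numba treats it as a compile-time constant (i.e. it must not change during the lifetime of the application). The same holds for the other globals used by the jitted functions (u0, Mmax, Nmax, dt, kappa and nth). If any of them can change, pass it as a parameter instead.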

    By the way, Julia might be better suited to this kind of computation, since its standard implementation is a JIT compiler and it makes it easy to avoid creating new temporary arrays (though creating new arrays is still fairly expensive, even in Julia).