performance numpy optimization function-approximation

A very quick method to approximate np.random.dirichlet with large dimension


I'd like to evaluate np.random.dirichlet with a large dimension as quickly as possible. More precisely, I'd like a function that approximates the one below and runs at least 10 times faster. Empirically, I observed that a small-dimension version of this function outputs one or two entries of order 0.1, while all other entries are so small that they are immaterial. This observation isn't based on any rigorous assessment, though. The approximation doesn't need to be very accurate, but I want something not too crude, as I'm using this noise for MCTS.

import timeit
import numpy as np

def g():
    return np.random.dirichlet([0.03]*4840)

>>> timeit.timeit(g,number=1000)
0.35117408499991143
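
The observation about only a few entries being material can be checked directly at the full dimension. The following is a minimal sketch, not a rigorous assessment: the helper name `count_material`, the cut-off of 0.01 for "material", the number of draws and the seed are all arbitrary choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, only for reproducibility

def count_material(alpha=0.03, dim=4840, n_draws=200, threshold=0.01):
    """For each draw, count the components above `threshold`
    and sum the probability mass those components carry."""
    samples = rng.dirichlet(np.full(dim, alpha), size=n_draws)
    large = samples > threshold
    counts = large.sum(axis=1)
    mass = np.where(large, samples, 0.0).sum(axis=1)
    return counts, mass

counts, mass = count_material()
print("mean number of components > 0.01:", counts.mean())
print("mean mass carried by them:       ", mass.mean())
print("mean largest component:          ",
      rng.dirichlet(np.full(4840, 0.03), size=200).max(axis=1).mean())
```

If the printed counts differ a lot from "one or two", the small-dimension intuition may not carry over to dimension 4840, which matters when judging how crude an approximation can be.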

Solution

  • Assuming your alpha is the same for all components and is reused for many draws, you could tabulate the ppf (inverse CDF) of the corresponding gamma distribution; a Dirichlet sample is just a vector of independent Gamma(alpha, 1) draws divided by their sum. The ppf is probably available as scipy.stats.gamma.ppf, but we can also use scipy.special.gammaincinv. This function seems rather slow, so building the table is a significant upfront investment.

    Here is a crude implementation of the general idea:

    import numpy as np
    from scipy import special
    
    class symm_dirichlet:
        def __init__(self, alpha, resolution=2**16):
            self.alpha = alpha
            self.resolution = resolution
            self.range, delta = np.linspace(0, 1, resolution,
                                            endpoint=False, retstep=True)
            self.range += delta / 2
            self.table = special.gammaincinv(self.alpha, self.range)
        def draw(self, n_sampl, n_comp, interp='nearest'):
            if interp != 'nearest':
                raise NotImplementedError
            gamma = self.table[np.random.randint(0, self.resolution,
                                                 (n_sampl, n_comp))]
            return gamma / gamma.sum(axis=1, keepdims=True)
    
    import time, timeit
    
    t0 = time.perf_counter()
    X = symm_dirichlet(0.03)
    t1 = time.perf_counter()
    print(f'Upfront cost {t1-t0:.3f} sec')
    print('Running cost per 1000 samples of width 4840')
    print('tabulated           {:3f} sec'.format(timeit.timeit(
        'X.draw(1, 4840)', number=1000, globals=globals())))
    print('np.random.dirichlet {:3f} sec'.format(timeit.timeit(
        'np.random.dirichlet([0.03]*4840)', number=1000, globals=globals())))
    

    Sample output:

    Upfront cost 13.067 sec
    Running cost per 1000 samples of width 4840
    tabulated           0.059365 sec
    np.random.dirichlet 0.980067 sec
    

    Better check whether it is roughly correct:

    [Image not available]
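
One way to run such a check numerically is to compare a couple of summary statistics of the tabulated sampler against np.random's own Dirichlet sampler and against a closed form. This is only a sketch under stated assumptions: it rebuilds the table inline with a reduced resolution of 2**12 (to keep the upfront cost low, at the price of a coarser tail), and the names `draw_tabulated` and `summary`, the seed and the sample count are illustrative choices, not part of the code above.

```python
import numpy as np
from scipy import special

rng = np.random.default_rng(0)   # arbitrary seed, only for reproducibility
alpha, n_comp, n_sampl = 0.03, 4840, 500
resolution = 2**12               # reduced from 2**16 so the table builds quickly

# Midpoint quantile table of Gamma(alpha, 1), same grid as in the class above.
q = (np.arange(resolution) + 0.5) / resolution
table = special.gammaincinv(alpha, q)

def draw_tabulated(n_sampl, n_comp):
    """Nearest-entry table lookup followed by normalisation."""
    g = table[rng.integers(0, resolution, (n_sampl, n_comp))]
    return g / g.sum(axis=1, keepdims=True)

def summary(x):
    """Mean largest component and mean sum of squares per draw."""
    return x.max(axis=1).mean(), (x**2).sum(axis=1).mean()

approx = draw_tabulated(n_sampl, n_comp)
exact = rng.dirichlet(np.full(n_comp, alpha), size=n_sampl)

# Closed form for a symmetric Dirichlet:
# E[sum_i x_i^2] = (alpha + 1) / (n_comp * alpha + 1)
theory = (alpha + 1) / (n_comp * alpha + 1)

max_a, sq_a = summary(approx)
max_e, sq_e = summary(exact)
print("                  mean max   mean sum of squares")
print("tabulated        {:9.5f}   {:9.5f}".format(max_a, sq_a))
print("rng.dirichlet    {:9.5f}   {:9.5f}".format(max_e, sq_e))
print("closed form              -   {:9.5f}".format(theory))
```

Because the table cuts off the upper tail of the gamma distribution at its last midpoint quantile, the tabulated sampler should be expected to slightly understate the largest components; raising the resolution narrows that gap at a higher upfront cost.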