I'd like to evaluate np.random.dirichlet with a large dimension as quickly as possible. More precisely, I'd like a function that approximates the call below and runs at least 10 times faster. Empirically, I observed that a small-dimension version of this function outputs one or two entries on the order of 0.1, while all other entries are so small that they are immaterial. But this observation isn't based on any rigorous assessment. The approximation doesn't need to be very accurate, but I want something not too crude, as I'm using this noise for MCTS.
def g():
    np.random.dirichlet([0.03]*4840)

>>> timeit.timeit(g,number=1000)
0.35117408499991143
Assuming your alpha is the same for all components and is reused over many iterations, you could tabulate the ppf (inverse CDF) of the corresponding gamma distribution, since a Dirichlet sample is just a vector of independent gamma(alpha) variates divided by their sum. The ppf is probably available as scipy.stats.gamma.ppf, but we can also use scipy.special.gammaincinv. This function seems rather slow, so this is a significant upfront investment.
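As a quick illustration of the building blocks, the sketch below checks that the two ppf routes agree for a unit-scale gamma distribution and that normalized inverse-CDF draws form a valid point on the simplex. The seed, the quantiles in `q`, and the variable names are arbitrary choices for this example, not part of the original answer.

```python
import numpy as np
from scipy import special, stats

alpha = 0.03
q = np.array([0.5, 0.9, 0.99])  # arbitrary quantiles for the comparison

# Two routes to the ppf of a gamma distribution with shape alpha and scale 1
ppf_stats = stats.gamma.ppf(q, alpha)
ppf_special = special.gammaincinv(alpha, q)

# Inverse-CDF sampling: uniform draws -> gamma variates -> normalize
rng = np.random.default_rng(0)  # seed chosen only for reproducibility
u = rng.random(4840)
g = special.gammaincinv(alpha, u)
sample = g / g.sum()

print(np.allclose(ppf_stats, ppf_special, rtol=1e-6, atol=0))
print(sample.sum(), sample.max())
```

Calling gammaincinv on fresh uniform draws every time is exact but slow, which is why tabulating it once on a fixed grid is attractive.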
Here is a crude implementation of the general idea:
import numpy as np
from scipy import special

class symm_dirichlet:
    def __init__(self, alpha, resolution=2**16):
        self.alpha = alpha
        self.resolution = resolution
        # midpoints of `resolution` equal-width bins covering (0, 1)
        self.range, delta = np.linspace(0, 1, resolution,
                                        endpoint=False, retstep=True)
        self.range += delta / 2
        # tabulate the gamma(alpha) ppf at those midpoints (slow, done once)
        self.table = special.gammaincinv(self.alpha, self.range)

    def draw(self, n_sampl, n_comp, interp='nearest'):
        if interp != 'nearest':
            raise NotImplementedError
        # a uniformly random table entry approximates one gamma(alpha) variate
        gamma = self.table[np.random.randint(0, self.resolution,
                                             (n_sampl, n_comp))]
        # normalizing independent gamma variates gives a Dirichlet sample
        return gamma / gamma.sum(axis=1, keepdims=True)
import time, timeit

t0 = time.perf_counter()
X = symm_dirichlet(0.03)
t1 = time.perf_counter()
print(f'Upfront cost {t1-t0:.3f} sec')
print('Running cost per 1000 samples of width 4840')
print('tabulated           {:3f} sec'.format(timeit.timeit(
    'X.draw(1, 4840)', number=1000, globals=globals())))
print('np.random.dirichlet {:3f} sec'.format(timeit.timeit(
    'np.random.dirichlet([0.03]*4840)', number=1000, globals=globals())))
Sample output:
Upfront cost 13.067 sec
Running cost per 1000 samples of width 4840
tabulated 0.059365 sec
np.random.dirichlet 0.980067 sec
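To put the upfront cost in perspective, here is a back-of-the-envelope break-even calculation using only the timings from the sample output above. These numbers are specific to that one run and machine, so treat the result as a rough guide.

```python
# Timings copied from the sample output above (one machine, one run)
upfront = 13.067                  # seconds to build the table
tab_per_draw = 0.059365 / 1000    # seconds per tabulated draw
np_per_draw = 0.980067 / 1000     # seconds per np.random.dirichlet draw

# Number of draws after which the table has paid for itself
break_even = upfront / (np_per_draw - tab_per_draw)
# Per-draw speedup once the table exists
speedup = np_per_draw / tab_per_draw

print(f'break-even after about {break_even:.0f} draws')
print(f'per-draw speedup about {speedup:.1f}x')
```

On these figures the table pays for itself after roughly 14,000 draws, and each draw afterwards is about 16.5 times faster.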
Better check whether it is roughly correct: