I wanted to draw samples from a Dirichlet-multinomial distribution using SciPy. Unfortunately, it seems that scipy.stats.dirichlet_multinomial does not define the rvs method that other distributions use to generate random samples.
I think this would be equivalent to the following for a single sample:
import scipy.stats as sps

def dirichlet_multinomial_sample(alpha, n, **kwargs):
    kwargs['size'] = 1  # force size to 1 for simplicity
    p = sps.dirichlet.rvs(alpha=alpha, **kwargs)
    return sps.multinomial.rvs(n=n, p=p.ravel(), **kwargs)
Multiple samples (i.e. size > 1) could be drawn similarly with a little more work to make it efficient. This seems easy enough to implement. My two questions are:
Is the above implementation correct?
This looks correct to me. Based on the discussion in the PR that implemented dirichlet_multinomial, SciPy does contain a bit of code that generates samples from a Dirichlet-multinomial distribution, but that code is only part of a test, not a public API.
One of the reviewers briefly touches on what you mention:
optional, probably follow-up PR add RVS method (as demonstrated in test_moments)
Looking up the code from the test case they're referencing, here's what it does:
import numpy as np
from scipy.stats import dirichlet_multinomial

rng = np.random.default_rng(28469824356873456)
n = rng.integers(1, 100)
alpha = rng.random(size=5) * 10
dist = dirichlet_multinomial(alpha, n)
# Generate a random sample from the distribution using NumPy
m = 100000
p = rng.dirichlet(alpha, size=m)  # one probability vector per sample
x = rng.multinomial(n, p, size=m)  # one multinomial draw per row of p
That would appear to be essentially the same thing you're doing, only using the equivalent NumPy API rather than the SciPy API. See 1 2.
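If you want something usable in the meantime, the same two-step idea can be wrapped in a small function. This is only a minimal sketch: the name dirichlet_multinomial_rvs and its size and rng parameters are my own invention, not part of SciPy, and it assumes NumPy 1.22 or newer so that multinomial can broadcast over the rows of p.

```python
import numpy as np

def dirichlet_multinomial_rvs(alpha, n, size=1, rng=None):
    """Draw `size` Dirichlet-multinomial samples.

    Hypothetical helper for illustration; not a SciPy or NumPy API.
    """
    # Accepts None, an integer seed, or an existing Generator
    rng = np.random.default_rng(rng)
    alpha = np.asarray(alpha, dtype=float)
    # Step 1: one probability vector per sample, shape (size, len(alpha))
    p = rng.dirichlet(alpha, size=size)
    # Step 2: one multinomial draw per row of p (needs NumPy >= 1.22)
    return rng.multinomial(n, p)

samples = dirichlet_multinomial_rvs([1.0, 2.0, 3.0], n=10, size=5, rng=0)
print(samples.shape)  # (5, 3)
```

Each row is one sample and sums to n. For large size, the column means should approach n * alpha / alpha.sum(), which is a quick sanity check.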
Another sticking point that was discussed was whether this would require SciPy to bump the minimum version of NumPy.
The problem is that it will be slow and cumbersome without NumPy 1.22 (vectorization over multinomial shape parameters), but that should be the minimum supported version before the next release of SciPy, so it should be OK.
Since that comment was written, SciPy has raised its minimum supported NumPy version to 1.23.5, so this is no longer a problem.
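To make "vectorization over multinomial shape parameters" concrete, here is a small comparison I put together (the variable names are mine). Without broadcasting, each probability vector needs its own multinomial call; with NumPy 1.22 or newer, a single call handles every row of p.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20
# Four probability vectors, shape (4, 3)
p = rng.dirichlet([2.0, 3.0, 5.0], size=4)

# Works on older NumPy: one multinomial call per probability vector
x_loop = np.array([rng.multinomial(n, row) for row in p])

# Requires NumPy >= 1.22: p is broadcast, one call covers all rows
x_vec = rng.multinomial(n, p)

print(x_loop.shape, x_vec.shape)  # (4, 3) (4, 3)
```

The two arrays will not hold identical draws, since they consume different random numbers, but both have one row per probability vector and every row sums to n.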
If it is, how can I suggest this enhancement to SciPy developers?
You can open an issue on SciPy's GitHub repository and ask the maintainers to add an rvs method.
You could also try fixing it yourself. If you do this, I would still recommend that you open an issue first. This will ensure that you and the maintainer are on the same page.
If you decide to fix it yourself, I would recommend that you read the following things:
rvs(), such as multivariate_normal.rvs()
.rvs()