
Where is scipy.stats.dirichlet_multinomial.rvs?


I wanted to draw samples from a Dirichlet-multinomial distribution using SciPy. Unfortunately, scipy.stats.dirichlet_multinomial does not seem to define the rvs method that other distributions use to generate random samples.

I think this would be equivalent to the following for a single sample:

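import scipy.stats as sps

def dirichlet_multinomial_sample(alpha, n, **kwargs):
    kwargs['size'] = 1  # force size to 1 for simplicity
    # Draw category probabilities p ~ Dirichlet(alpha); shape (1, k)
    p = sps.dirichlet.rvs(alpha=alpha, **kwargs)
    # Draw counts ~ Multinomial(n, p); multinomial expects a 1-D p, so ravel it
    return sps.multinomial.rvs(n=n, p=p.ravel(), **kwargs)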

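Multiple samples (i.e., size > 1) could be drawn in a similar way, with a bit more work to make it efficient. This seems easy enough to implement. My two questions are:

  • Is the above implementation correct?
  • If it is, how can I suggest this enhancement to SciPy developers?
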
Solution

  • Is the above implementation correct?

    This looks correct to me. Based on the discussion in the PR that added dirichlet_multinomial, SciPy does contain a bit of code that generates samples from a Dirichlet-multinomial distribution, but that code is only part of a test, not a public API.

    One of the reviewers briefly touches on what you mention:

    optional, probably follow-up PR add RVS method (as demonstrated in test_moments)

    Here is the relevant part of the test case they are referring to:

    https://github.com/scipy/scipy/pull/17211/files#diff-a998a313f078eba79aeb6347a65117e0a0c4542a4d778a4cfd398f1737380a71R3030

        import numpy as np
        from scipy.stats import dirichlet_multinomial

        rng = np.random.default_rng(28469824356873456)
        n = rng.integers(1, 100)
        alpha = rng.random(size=5) * 10
        dist = dirichlet_multinomial(alpha, n)

        # Generate a random sample from the distribution using NumPy
        m = 100000
        p = rng.dirichlet(alpha, size=m)
        x = rng.multinomial(n, p, size=m)

    This is essentially the same thing you are doing, only with the equivalent NumPy API (Generator.dirichlet and Generator.multinomial) instead of the SciPy one.

    Another sticking point that was discussed was whether this would require SciPy to bump the minimum version of NumPy.

    The problem is that it will be slow and cumbersome without NumPy 1.22 (vectorization over multinomial shape parameters), but that should be the minimum supported version before the next release of SciPy, so it should be OK.

    Since then, SciPy has raised its minimum supported NumPy version to 1.23.5, so this is no longer a problem.
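    To illustrate what the vectorization buys, here is a sketch contrasting a per-row loop (the kind of code needed when multinomial cannot broadcast over its pvals argument) with a single broadcast call. The parameter values are arbitrary; both arrays end up with shape (m, 3), and every row sums to n.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])
n, m = 10, 1_000

p = rng.dirichlet(alpha, size=m)  # one probability vector per row, shape (m, 3)

# Without vectorization over pvals: one multinomial call per probability vector.
x_loop = np.array([rng.multinomial(n, row) for row in p])

# NumPy >= 1.22: a single call broadcasts over the rows of p.
x_vec = rng.multinomial(n, p, size=m)
```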

  • If it is, how can I suggest this enhancement to SciPy developers?

    You can open an issue on SciPy's GitHub issue tracker asking for an rvs method to be added to dirichlet_multinomial.

    You could also try implementing it yourself. If you do, I would still recommend opening an issue first, so that you and the maintainers agree on the approach before you write the code.

    If you decide to fix it yourself, I would recommend that you read the following things: