I'm trying to generate multinomial random variables as fast as possible, and I learned that gsl_ran_multinomial could be a good choice. However, when I tried to use it based on the answer in this post: https://stackoverflow.com/a/23100665/21039115, the results were always wrong.
My code is:
// [[Rcpp::export]]
arma::ivec rmvn_gsl(int K, arma::vec prob) {
    gsl_rng *s = gsl_rng_alloc(gsl_rng_mt19937); // Create RNG seed
    arma::ivec temp(K);
    gsl_ran_multinomial(s, K, 1, prob.begin(), (unsigned int *) temp.begin());
    gsl_rng_free(s); // Free memory
    return temp;
}
And the result was something like
rmvn_gsl(3, c(0.2, 0.7, 0.1))
     [,1]
[1,]    1
[2,]    0
[3,]    0
which is ridiculous.
I was wondering whether there is a problem in the code. I couldn't find any other examples to compare against. I appreciate any help!
UPDATE:
I found that the primary problem is that I didn't set a random seed, and it seems that GSL has its own default seed (FYI: https://stackoverflow.com/a/32939816/21039115), so every call returned the same draw. Once I seeded the generator from the current time, the code worked. That said, I will go with rmultinom, since in my microbenchmark runs it can be even faster than gsl_ran_multinomial.
Anyway, @Dirk Eddelbuettel provided a great example of implementing gsl_ran_multinomial below. Just pay attention to the random seed issue if you run into the same problem I did.
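The seed pitfall is not specific to GSL, so here is a minimal sketch of it using only the C++ standard library rather than GSL itself. The helper names multinomial_draw and make_seeded_rng are hypothetical and made up for illustration. The sketch seeds from std::random_device instead of the current time. With GSL, the analogous step would be calling gsl_rng_set on the generator after gsl_rng_alloc.

```cpp
#include <cassert>
#include <random>
#include <vector>

// Hypothetical helper (not a GSL or Rcpp function): draws one multinomial
// sample of n trials by tallying n categorical draws.
std::vector<unsigned int> multinomial_draw(std::mt19937 &rng, unsigned int n,
                                           const std::vector<double> &prob) {
    assert(!prob.empty());
    // discrete_distribution normalizes its weights internally.
    std::discrete_distribution<int> pick(prob.begin(), prob.end());
    std::vector<unsigned int> counts(prob.size(), 0u);
    for (unsigned int i = 0; i < n; ++i) {
        ++counts[pick(rng)];
    }
    return counts;
}

// Hypothetical helper: builds a generator seeded from the OS entropy source,
// so separate runs do not replay the same stream.
std::mt19937 make_seeded_rng() {
    std::random_device rd;
    return std::mt19937(rd());
}
```

A default-constructed std::mt19937 always starts from the same fixed seed, so two fresh generators passed to multinomial_draw return identical counts, which mirrors the repeated result above. Storing the result of make_seeded_rng() in a variable and reusing that one generator across calls should give draws that differ between runs.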
Here is a complete example taking a double vector of probabilities and returning an unsigned integer vector (at the compiled level) that is mapped to an integer vector by the time we are back in R:
#include <RcppGSL.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>

// [[Rcpp::depends(RcppGSL)]]

// [[Rcpp::export]]
std::vector<unsigned int> foo(int n, const std::vector<double> p) {
    int k = p.size();
    std::vector<unsigned int> nv(k);
    gsl_rng_env_setup();
    gsl_rng *s = gsl_rng_alloc(gsl_rng_mt19937); // Create RNG instance
    gsl_ran_multinomial(s, k, n, &(p[0]), &(nv[0]));
    gsl_rng_free(s);
    return nv;
}

/*** R
foo(400, c(0.1, 0.2, 0.3, 0.4))
*/
> Rcpp::sourceCpp("~/git/stackoverflow/75165241/answer.cpp")
> foo(400, c(0.1, 0.2, 0.3, 0.4))
[1] 37 80 138 145
>