
Scala Breeze Dirichlet distribution parameter estimation


I am trying to estimate the parameters of a Dirichlet distribution for a data set using Scala's Breeze library. I already have working Python (pandas DataFrames) and R code for this, but I was curious how to do it in Scala. I am also new to Scala.

I can't seem to get it to work. I suspect I have the syntax wrong somewhere.

The code I am trying to use is here: https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/stats/distributions/Dirichlet.scala#L111

According to that code, ExpFam[T,I] takes two type parameters, T and I. I don't know what T and I are. Can T be a DenseMatrix?

What I am doing is:

// Create a matrix. The values are counts in my case.
val mat = DenseMatrix((1.0, 2.0, 3.0),(4.0, 5.0, 6.0))

// Then try to get the sufficient statistics and the MLE. I think this is where I am doing something wrong.
val diri = new ExpFam[DenseMatrix[Double],Int](mat)
println(diri.sufficientStatisticFor(mat))

Also, given a data matrix such as DenseMatrix((1.0, 2.0, 3.0),(4.0, 5.0, 6.0)), how do I estimate the Dirichlet parameters in Scala?


Solution

  • I am not very familiar with this part of Breeze, but this works for me:

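    import breeze.linalg.DenseVector
    import breeze.stats.distributions.Dirichlet

    // Each observation is a point on the simplex (components sum to 1).
    val data = Seq(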
      DenseVector(0.1, 0.1, 0.8),
      DenseVector(0.2, 0.3, 0.5),
      DenseVector(0.5, 0.1, 0.4),
      DenseVector(0.3, 0.3, 0.4)
    )
    
    // The argument is an exemplar vector; here it fixes the dimension at 3.
    val expFam = new Dirichlet.ExpFam(DenseVector.zeros[Double](3))
    
    val suffStat = data.foldLeft(expFam.emptySufficientStatistic){(a, x) => 
      a + expFam.sufficientStatisticFor(x)
    }
    
    val alphaHat = expFam.mle(suffStat)
    //DenseVector(2.9803000577558274, 2.325871404559782, 5.850530402841005)
    

    The result is very close to, but not exactly the same as, what I get with my own code for maximum likelihood estimation of Dirichlets. The difference probably comes down to the optimizer being used (I'm using the fixed-point iteration, equation (9) in section 1 of this paper by T. Minka) and the stopping criteria.
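
For reference, the fixed-point iteration mentioned above can be sketched in plain Scala without Breeze. This is only a rough sketch under stated assumptions, not Breeze's implementation and not necessarily the code behind the comparison above: the object name `DirichletFixedPoint`, the helper names, the all-ones starting point, the iteration cap and the tolerance are hypothetical choices for illustration, and digamma and trigamma are approximated with a standard asymptotic series. Each step solves psi(alpha_k_new) = psi(sum_j alpha_j_old) + mean_i log(p_ik) for alpha_k_new using Newton's method.

```scala
object DirichletFixedPoint {
  // Digamma: shift x up with psi(x) = psi(x + 1) - 1/x, then use an asymptotic series.
  def digamma(x0: Double): Double = {
    var x = x0
    var acc = 0.0
    while (x < 6.0) { acc -= 1.0 / x; x += 1.0 }
    val inv = 1.0 / x
    val inv2 = inv * inv
    acc + math.log(x) - 0.5 * inv - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))
  }

  // Trigamma (derivative of digamma), using the same shift-then-series approach.
  def trigamma(x0: Double): Double = {
    var x = x0
    var acc = 0.0
    while (x < 6.0) { acc += 1.0 / (x * x); x += 1.0 }
    val inv = 1.0 / x
    val inv2 = inv * inv
    acc + inv + 0.5 * inv2 + inv * inv2 * (1.0 / 6 - inv2 * (1.0 / 30 - inv2 / 42))
  }

  // Inverse digamma by Newton's method, with a piecewise starting guess.
  def invDigamma(y: Double): Double = {
    var x = if (y >= -2.22) math.exp(y) + 0.5 else -1.0 / (y - digamma(1.0))
    for (_ <- 0 until 10) x -= (digamma(x) - y) / trigamma(x)
    x
  }

  // Fixed-point iteration for the Dirichlet MLE.
  // Assumes every row of `data` is strictly positive and sums to 1.
  def mle(data: Seq[Array[Double]], maxIter: Int = 1000, tol: Double = 1e-10): Array[Double] = {
    val k = data.head.length
    // Sufficient statistic: the mean of log p_ik over observations, per component.
    val meanLog = Array.tabulate(k)(j => data.map(row => math.log(row(j))).sum / data.size)
    var alpha = Array.fill(k)(1.0) // assumed starting point
    var iter = 0
    var delta = Double.MaxValue
    while (iter < maxIter && delta > tol) {
      val psiSum = digamma(alpha.sum)
      val next = meanLog.map(m => invDigamma(psiSum + m))
      delta = next.zip(alpha).map { case (a, b) => math.abs(a - b) }.max
      alpha = next
      iter += 1
    }
    alpha
  }

  def main(args: Array[String]): Unit = {
    val data = Seq(
      Array(0.1, 0.1, 0.8),
      Array(0.2, 0.3, 0.5),
      Array(0.5, 0.1, 0.4),
      Array(0.3, 0.3, 0.4)
    )
    println(mle(data).mkString("alphaHat = (", ", ", ")"))
  }
}
```

Running `DirichletFixedPoint.main` on the same four observations gives an estimate that can be compared against the `alphaHat` from Breeze; small differences are to be expected given the different optimizer and stopping rule.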

    Maybe there's a better way of doing this using the Breeze API; if so, hopefully @dlwh or someone else more familiar with Breeze will chime in.