Tags: rcpp, multinomial, hessian-matrix

Faster way to calculate the Hessian / Fisher Information Matrix of a nnet::multinom multinomial regression in R using Rcpp & Kronecker products


For larger nnet::multinom multinomial regression models (with a few thousand coefficients), calculating the Hessian (the matrix of second derivatives of the negative log-likelihood, also known as the observed Fisher information matrix) becomes very slow, which in turn prevents me from obtaining the variance-covariance matrix and from calculating confidence intervals on model predictions.
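(For context, this is roughly the downstream calculation the Hessian is needed for; a minimal sketch, assuming hess holds the Hessian of a fitted nnet::multinom object fit, e.g. as returned by nnet:::multinomHess, and using plain 1.96 Wald intervals just for illustration:)

vc <- solve(hess)                    # vcov matrix = inverse of the observed information
se <- sqrt(diag(vc))                 # standard errors of the stacked coefficients
cf <- as.vector(t(coef(fit)))        # coefficients in the same order as the rows of vc
ci <- cbind(lower = cf - 1.96 * se,  # approximate 95% Wald confidence intervals
            upper = cf + 1.96 * se)
# (confidence intervals on model predictions would additionally need e.g. the delta method)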

The culprit appears to be the following pure R function, which calculates the Fisher information matrix analytically using code contributed by David Firth: https://github.com/cran/nnet/blob/master/R/vcovmultinom.R

multinomHess = function (object, Z = model.matrix(object)) 
{
    probs <- object$fitted
    coefs <- coef(object)
    if (is.vector(coefs)) {
        coefs <- t(as.matrix(coefs))
        probs <- cbind(1 - probs, probs)
    }
    coefdim <- dim(coefs)
    p <- coefdim[2L]
    k <- coefdim[1L]
    ncoefs <- k * p
    kpees <- rep(p, k)
    n <- dim(Z)[1L]

##  Now compute the observed (= expected, in this case) information,
##  e.g. as in T Amemiya "Advanced Econometrics" (1985) pp 295-6.
##  Here i and j are as in Amemiya, and x, xbar are vectors
##  specific to (i,j) and to i respectively.

    info <- matrix(0, ncoefs, ncoefs)
    Names <- dimnames(coefs)
    if (is.null(Names[[1L]])) 
        Names <- Names[[2L]]
    else Names <- as.vector(outer(Names[[2L]], Names[[1L]], function(name2, 
        name1) paste(name1, name2, sep = ":")))
    dimnames(info) <- list(Names, Names)
    x0 <- matrix(0, p, k + 1L)
    row.totals <- object$weights
    for (i in seq_len(n)) {
        Zi <- Z[i, ]
        xbar <- rep(Zi, times=k) * rep(probs[i, -1, drop=FALSE], times=kpees)
        for (j in seq_len(k + 1)) {
            x <- x0
            x[, j] <- Zi
            x <- x[, -1, drop = FALSE]
            x <- x - xbar
            dim(x) <- c(1, ncoefs)
            info <- info + (row.totals[i] * probs[i, j] * crossprod(x))
        }
    }
    info
}

The referenced result, from T Amemiya, "Advanced Econometrics" (1985), pp. 295-296 (shown as two equation images in the original post), states that the observed (= expected) information is a sum, over observations i and outcome categories j, of crossproduct terms n_i · p_ij · x'x, with x the (i,j)-specific vector constructed in the loop above.

From this explanation, we can see that the Hessian is indeed just a sum of crossproducts. I also came across this and this derivation of how to calculate the Hessian of a multinomial regression model, which may be even more elegant and efficient, since there the Hessian is written as a sum of Kronecker products.
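In symbols (my reading of those linked derivations; assuming a single multinomial draw per row, with the first outcome category used as reference and dropped), each observation contributes one Kronecker product, so that

H = Σ_{i=1..n} ( diag(p_i) - p_i p_i' ) ⊗ ( z_i z_i' )

where p_i is the vector of fitted probabilities of observation i (first column dropped), z_i is the i-th row of the model matrix Z, and ⊗ denotes the Kronecker product.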

For a smallish nnet::multinom model (in which I am modelling the frequency of different SARS-CoV-2 lineages through time), the provided function runs quickly:

library(nnet)
library(splines)
download.file("https://www.dropbox.com/s/gt0yennn2gkg3rd/smallmodel.RData?dl=1",
              "smallmodel.RData", 
              method = "auto", mode="wb")
load("smallmodel.RData")
length(fit_multinom_small$lev) # k=12 outcome levels
dim(coef(fit_multinom_small)) # 11 x 3 = (k-1) x p = 33 coefs
system.time(hess <- nnet:::multinomHess(fit_multinom_small)) # 0.11s
dim(hess) # 33 33

but doing this for a large model (again modelling the frequency of different SARS-CoV-2 lineages through time, but now across different continents/countries) takes more than 2 hours, even though the model itself fits in ca. 1 minute:

download.file("https://www.dropbox.com/s/mpz08jj7fmubd68/bigmodel.RData?dl=1",
              "bigmodel.RData", 
              method = "auto", mode="wb")
load("bigmodel.RData")
length(fit_global_multi_last3m$lev) # k=20 outcome levels
dim(coef(fit_global_multi_last3m)) # 19 x 229 = (k-1) x p = 4351 coefficients
system.time(hess <- nnet:::multinomHess(fit_global_multi_last3m)) # takes forever

I was now looking for ways to speed up the above function.

The obvious attempt would be to port it to Rcpp, but unfortunately I am not very experienced with that. Does anybody have any thoughts?

EDIT: From the info here and here, it appears that calculating the Hessian of a multinomial fit should come down to a sum of Kronecker products, which we can do from R using efficient matrix algebra, but right now I am unsure how to include my total row counts fit$weights. Does anybody have any idea?

download.file("https://www.dropbox.com/s/gt0yennn2gkg3rd/smallmodel.RData?dl=1",
                      "smallmodel.RData", 
                      method = "auto", mode="wb")

load("smallmodel.RData")
library(nnet)
length(fit_multinom_small$lev) # k=12 outcome levels
dim(coef(fit_multinom_small)) # 11 x 3 = (k-1) x p = 33 coefs

fit = fit_multinom_small

Z = model.matrix(fit)
P = fitted(fit)[, -1, drop=F]
k = ncol(P) # nr of outcome categories-1
p = ncol(Z) # nr of parameters
n = nrow(Z) # nr of observations
ncoefs = k*p
library(fastmatrix)

# Fisher information matrix
info <- matrix(0, ncoefs, ncoefs)
for (i in 1:n) { # sum over observations
  info = info + kronecker.prod(diag(P[i,]) - tcrossprod(P[i,]), tcrossprod(Z[i,]))
}

Solution

  • Figured it out in the end: I was able to calculate the observed Fisher information matrix using Kronecker products and to port that bit to Rcpp using Armadillo classes. (Full disclosure: I made that Rcpp port just using OpenAI's code-davinci / Codex, https://openai.com/blog/openai-codex/, and surprisingly it worked straight out of the box - AI is getting better every day. parallelReduce could presumably still be used to parallelise the accumulation over observations, and the function below was faster than an equivalent RcppEigen implementation I tried.) The mistake I had made was that the formula above gives the observed Fisher information for a single observation, so I had to accumulate it over observations and also weight each observation's contribution by its total row count (object$weights).

    Rcpp function:

    // RcppArmadillo utility function to calculate observed Fisher 
    // information matrix of multinomial fit, with 
    // probs=fitted probabilities (with 1st category/column dropped)
    // Z = model matrix
    // row_totals = row totals
    // We do this using Kronecker products, as in
    // https://ieeexplore.ieee.org/abstract/document/1424458
    // B. Krishnapuram; L. Carin; M.A.T. Figueiredo; A.J. Hartemink
    // Sparse multinomial logistic regression: fast algorithms and
    // generalization bounds
    // IEEE Transactions on Pattern Analysis and Machine
    // Intelligence ( Volume: 27, Issue: 6, June 2005)
    
    #include <RcppArmadillo.h>
    
    using namespace arma;
    
    // [[Rcpp::depends(RcppArmadillo)]]
    // [[Rcpp::export]]
    arma::mat calc_infmatrix_RcppArma(arma::mat probs, arma::mat Z, arma::vec row_totals) {
      int n = Z.n_rows;
      int p = Z.n_cols;
      int k = probs.n_cols;
      int ncoefs = k * p;
      arma::mat info = arma::zeros<arma::mat>(ncoefs, ncoefs);
      arma::mat diag_probs;
      arma::mat tcrossprod_probs;
      arma::mat tcrossprod_Z;
      arma::mat kronecker_prod;
      for (int i = 0; i < n; i++) {
        diag_probs = arma::diagmat(probs.row(i));
        tcrossprod_probs = arma::trans(probs.row(i)) * probs.row(i);
        tcrossprod_Z = (arma::trans(Z.row(i)) * Z.row(i)) * row_totals(i);
        kronecker_prod = arma::kron(diag_probs - tcrossprod_probs, tcrossprod_Z);
        info += kronecker_prod;
      }
      return info;
    }
    

    I saved this as "calc_infmatrix_arma.cpp" and compiled & loaded it with:

    library(Rcpp)
    library(RcppArmadillo)
    sourceCpp("calc_infmatrix_arma.cpp")
    

    R wrapper function:

    # Function to calculate Hessian / observed Fisher information
    # matrix of nnet::multinom multinomial fit object
    fastmultinomHess <- function(object, Z = model.matrix(object)) {
      
      probs <- object$fitted # predicted probabilities, avoid napredict from fitted.default
    
      coefs <- coef(object)
      if (is.vector(coefs)){ # ie there are only 2 response categories
        coefs <- t(as.matrix(coefs))
        probs <- cbind(1 - probs, probs)
      }
      coefdim <- dim(coefs)
      p <- coefdim[2L] # nr of parameters
      k <- coefdim[1L] # nr of outcome categories - 1
      ncoefs <- k * p # nr of coefficients
      n <- dim(Z)[1L] # nr of observations
      
      #  Now compute the Hessian = the observed 
      #  (= expected, in this case) 
      #  Fisher information matrix
        
      info <- calc_infmatrix_RcppArma(probs = probs[, -1, drop=F], 
                                      Z = Z, 
                                      row_totals = object$weights)
    
      Names <- dimnames(coefs)
      if (is.null(Names[[1L]])) {
        Names <- Names[[2L]]
      } else {
        Names <- as.vector(outer(Names[[2L]], Names[[1L]],
                                 function(name2, name1) paste(name1, name2, sep = ":")))
      }
      dimnames(info) <- list(Names, Names)
    
      return(info)
    }
    

    For my larger model this now calculates in ca. 100 s instead of more than 2 hours, i.e. almost 80 times faster:

    download.file("https://www.dropbox.com/s/mpz08jj7fmubd68/bigmodel.RData?dl=1",
                  "bigmodel.RData", 
                  method = "auto", mode="wb")
    load("bigmodel.RData")
    object = fit_global_multi_last3m # large nnet::multinom fit
    system.time(info <- fastmultinomHess(object, Z = model.matrix(object))) # 103s
    system.time(info <- nnet:::multinomHess(object, Z = model.matrix(object))) # 8127s = 2.25h
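
    As a quick sanity check (my suggestion, using the small model from the question, where the reference implementation is still fast, and assuming smallmodel.RData is still in the working directory), one can verify that both implementations agree:

    load("smallmodel.RData")
    hess_fast <- fastmultinomHess(fit_multinom_small)
    hess_nnet <- nnet:::multinomHess(fit_multinom_small)
    all.equal(hess_fast, hess_nnet)  # should be TRUE up to numerical tolerance
    max(abs(hess_fast - hess_nnet))  # should be ~0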
    

    A pure R version of the calc_infmatrix function (ca. 5x slower than the Rcpp version above) would be:

    # Utility function to calculate observed Fisher information matrix
    # of multinomial fit, with 
    # probs=fitted probabilities (with 1st category/column dropped)
    # Z = model matrix
    # row_totals = row totals
    calc_infmatrix <- function(probs, Z, row_totals) {
      require(fastmatrix) # for the kronecker.prod Kronecker product function
      
      n <- nrow(Z)
      p <- ncol(Z)
      k <- ncol(probs)
      ncoefs <- k * p
      info <- matrix(0, ncoefs, ncoefs)
      for (i in 1:n) {
        info <- info + kronecker.prod((diag(probs[i, ]) - tcrossprod(probs[i, ])),
                                      tcrossprod(Z[i, ]) * row_totals[i])
      }
      return(info)
    }
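
    For example (my illustration, reusing the small model and the same inputs as in the question above), it can be called directly as:

    Z <- model.matrix(fit_multinom_small)
    info_R <- calc_infmatrix(probs = fitted(fit_multinom_small)[, -1, drop = FALSE],
                             Z = Z,
                             row_totals = fit_multinom_small$weights)
    dim(info_R) # 33 x 33, matching nnet:::multinomHess(fit_multinom_small)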