It appears that for larger nnet::multinom multinomial regression models (with a few thousand coefficients), calculating the Hessian (the matrix of second derivatives of the negative log likelihood, also known as the observed Fisher information matrix) becomes extremely slow. That prevents me from calculating the variance-covariance matrix, which I need to calculate confidence intervals on model predictions.
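For context, these are the standard large-sample relations that make the Hessian the bottleneck. They are general maximum-likelihood facts, not nnet-specific code. Here $\hat\theta$ is the stacked coefficient vector, $\mathcal{I}(\hat\theta)$ is the observed information, and $g$ is any smooth function of the coefficients, such as a predicted probability:

```latex
\widehat{\operatorname{Var}}(\hat\theta) = \mathcal{I}(\hat\theta)^{-1},
\qquad
\hat\theta_r \pm z_{1-\alpha/2}\,\sqrt{\bigl[\mathcal{I}(\hat\theta)^{-1}\bigr]_{rr}},
\qquad
\widehat{\operatorname{Var}}\bigl(g(\hat\theta)\bigr) \approx \nabla g(\hat\theta)^{\top}\,\mathcal{I}(\hat\theta)^{-1}\,\nabla g(\hat\theta)
```

With a few thousand coefficients, each of these needs the full ncoefs × ncoefs information matrix or its inverse.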
The culprit seems to be the following pure R function, which calculates the Fisher information matrix analytically using code contributed by David Firth: https://github.com/cran/nnet/blob/master/R/vcovmultinom.R
multinomHess = function (object, Z = model.matrix(object))
{
    probs <- object$fitted
    coefs <- coef(object)
    if (is.vector(coefs)) {
        coefs <- t(as.matrix(coefs))
        probs <- cbind(1 - probs, probs)
    }
    coefdim <- dim(coefs)
    p <- coefdim[2L]
    k <- coefdim[1L]
    ncoefs <- k * p
    kpees <- rep(p, k)
    n <- dim(Z)[1L]
    ## Now compute the observed (= expected, in this case) information,
    ## e.g. as in T Amemiya "Advanced Econometrics" (1985) pp 295-6.
    ## Here i and j are as in Amemiya, and x, xbar are vectors
    ## specific to (i,j) and to i respectively.
    info <- matrix(0, ncoefs, ncoefs)
    Names <- dimnames(coefs)
    if (is.null(Names[[1L]]))
        Names <- Names[[2L]]
    else Names <- as.vector(outer(Names[[2L]], Names[[1L]], function(name2,
        name1) paste(name1, name2, sep = ":")))
    dimnames(info) <- list(Names, Names)
    x0 <- matrix(0, p, k + 1L)
    row.totals <- object$weights
    for (i in seq_len(n)) {
        Zi <- Z[i, ]
        xbar <- rep(Zi, times=k) * rep(probs[i, -1, drop=FALSE], times=kpees)
        for (j in seq_len(k + 1)) {
            x <- x0
            x[, j] <- Zi
            x <- x[, -1, drop = FALSE]
            x <- x - xbar
            dim(x) <- c(1, ncoefs)
            info <- info + (row.totals[i] * probs[i, j] * crossprod(x))
        }
    }
    info
}
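Written as a formula, the double loop above accumulates the following. The notation is mine, not nnet's: $w_i$ = `row.totals[i]`, $z_i$ = `Z[i, ]`, and $p_{ij}$ = `probs[i, j + 1]`, so that $j = 0$ is the reference category. $p_i = (p_{i1}, \dots, p_{ik})^\top$, and $e_j$ is the $j$-th unit vector of length $k$:

```latex
x_{ij} = e_j \otimes z_i \;\; (j = 1, \dots, k), \qquad x_{i0} = 0, \qquad
\bar{x}_i = p_i \otimes z_i,
\qquad
\mathcal{I} = \sum_{i=1}^{n} w_i \sum_{j=0}^{k} p_{ij}\,(x_{ij} - \bar{x}_i)(x_{ij} - \bar{x}_i)^{\top}
```

Here $x_{ij}$ is the vectorized `x` after the first column of `x0` is dropped, and $\bar{x}_i$ is `xbar`. Each of the $k + 1$ inner iterations forms a full ncoefs × ncoefs matrix via `crossprod(x)`.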
The referenced passage in Amemiya's Advanced Econometrics (1985, pp. 295-6) was included as an excerpt that is not reproduced here.
From this explanation, we can see that the Hessian is indeed just a sum of cross-products. I also found this and this derivation of the Hessian of a multinomial regression model, which may be even more elegant and efficient, because there the Hessian is computed as a sum of Kronecker products.
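The step from the cross-product form to a Kronecker product can be made explicit. This is my own derivation, with the following definitions:

- $x_{ij} = e_j \otimes z_i$ for the non-reference categories $j = 1, \dots, k$, and $x_{i0} = 0$ for the reference category.
- $\bar x_i = \sum_j p_{ij} x_{ij} = p_i \otimes z_i$.
- $\sum_{j=0}^{k} p_{ij} = 1$.

It uses the mixed-product rule $(A \otimes B)(C \otimes D) = AC \otimes BD$:

```latex
\sum_{j=0}^{k} p_{ij}\,(x_{ij} - \bar x_i)(x_{ij} - \bar x_i)^{\top}
  = \sum_{j=1}^{k} p_{ij}\, x_{ij} x_{ij}^{\top} - \bar x_i \bar x_i^{\top}
  = \sum_{j=1}^{k} p_{ij}\,(e_j e_j^{\top}) \otimes (z_i z_i^{\top}) - (p_i p_i^{\top}) \otimes (z_i z_i^{\top})
  = \left(\operatorname{diag}(p_i) - p_i p_i^{\top}\right) \otimes z_i z_i^{\top}
```

Multiplying by the row total $w_i$ and summing over $i = 1, \dots, n$ then gives the full information matrix.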
For a smallish nnet::multinom model (in which I am modelling the frequency of different SARS-CoV2 lineages through time) the provided function runs quickly:
library(nnet)
library(splines)
download.file("https://www.dropbox.com/s/gt0yennn2gkg3rd/smallmodel.RData?dl=1",
              "smallmodel.RData",
              method = "auto", mode = "wb")
load("smallmodel.RData")
length(fit_multinom_small$lev) # k=12 outcome levels
dim(coef(fit_multinom_small)) # 11 x 3 = (k-1) x p = 33 coefs
system.time(hess <- nnet:::multinomHess(fit_multinom_small)) # 0.11s
dim(hess) # 33 33
but doing this for a large model takes more than 2 hours, even though the model itself fits in ca. 1 minute (again modelling the frequency of different SARS-CoV2 lineages through time, but now across different continents / countries):
download.file("https://www.dropbox.com/s/mpz08jj7fmubd68/bigmodel.RData?dl=1",
              "bigmodel.RData",
              method = "auto", mode = "wb")
load("bigmodel.RData")
length(fit_global_multi_last3m$lev) # k=20 outcome levels
dim(coef(fit_global_multi_last3m)) # 19 x 229 = (k-1) x p = 4351 coefficients
system.time(hess <- nnet:::multinomHess(fit_global_multi_last3m)) # takes forever
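A rough back-of-the-envelope estimate shows why the loop is so slow. This is my own arithmetic, assuming double-precision storage; the number of observations n is not given here. Every inner iteration builds a dense ncoefs × ncoefs cross-product:

```latex
\text{ncoefs} = 19 \times 229 = 4351, \qquad
4351^2 = 18\,931\,201 \text{ entries}, \qquad
18\,931\,201 \times 8 \text{ bytes} \approx 151 \text{ MB per } \texttt{crossprod(x)}
```

There are 20 such terms per observation, one per outcome level. Each is followed by a scalar multiplication and a matrix addition that create further full-size temporaries.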
I am now looking for ways to speed up the above function.
The obvious attempt would be to port it to Rcpp, but unfortunately I am not very experienced with that. Does anybody have any thoughts?
EDIT: From the info here and here, it appears that calculating the Hessian of a multinomial fit should just come down to calculating a sum of Kronecker products, which we can do directly in R using efficient matrix algebra. Right now, though, I am unsure how to include my total row counts fit$weights. Does anybody have an idea?
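One way to see how the row totals enter is sketched below in my own notation:

- $w_i$ is the row total of observation $i$.
- $p_{ij}$ is the fitted probability of non-reference category $j$.
- $z_i$ is row $i$ of the model matrix $Z$.

Each observation's contribution is simply scaled by $w_i$. Collecting terms block by block then gives a matrix form with no explicit loop over observations:

```latex
\mathcal{I} = \sum_{i=1}^{n} w_i \left(\operatorname{diag}(p_i) - p_i p_i^{\top}\right) \otimes z_i z_i^{\top},
\qquad
\mathcal{I}_{jl} = \sum_{i=1}^{n} w_i \left(\delta_{jl}\, p_{ij} - p_{ij} p_{il}\right) z_i z_i^{\top}
  = Z^{\top} \operatorname{diag}(v_{jl})\, Z,
\quad (v_{jl})_i = w_i \left(\delta_{jl}\, p_{ij} - p_{ij} p_{il}\right)
```

In R, each p × p block $\mathcal{I}_{jl}$ could then be formed as crossprod(Z, Z * v), since Z * v scales row i of Z by v[i]. This has not been benchmarked here.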
download.file("https://www.dropbox.com/s/gt0yennn2gkg3rd/smallmodel.RData?dl=1",
              "smallmodel.RData",
              method = "auto", mode = "wb")
load("smallmodel.RData")
library(nnet)
length(fit_multinom_small$lev) # k=12 outcome levels
dim(coef(fit_multinom_small)) # 11 x 3 = (k-1) x p = 33 coefs
fit = fit_multinom_small
Z = model.matrix(fit)
P = fitted(fit)[, -1, drop = FALSE]
k = ncol(P) # nr of outcome categories - 1
p = ncol(Z) # nr of parameters
n = nrow(Z) # nr of observations
ncoefs = k*p
library(fastmatrix)
# Fisher information matrix (row totals fit$weights not included yet)
info <- matrix(0, ncoefs, ncoefs)
for (i in 1:n) { # sum over observations
  # nrow = k keeps diag() a k x k matrix even when k == 1
  info = info + kronecker.prod(diag(P[i, ], nrow = k) - tcrossprod(P[i, ]), tcrossprod(Z[i, ]))
}
I figured it out in the end. I was able to calculate the observed Fisher information matrix using Kronecker products, and to port that part to Rcpp using Armadillo classes.

Full disclosure: I made that Rcpp port just using OpenAI's code-davinci / Codex (https://openai.com/blog/openai-codex/), and surprisingly it worked straight out of the box. AI is getting better every day. I presume parallelReduce could still be used to parallelize the accumulation. The function was already faster than an equivalent RcppEigen implementation I tried.

The mistake I made was that the formula above is the observed Fisher information for a single observation. It has to be accumulated over observations, with each observation's contribution multiplied by its total row count (object$weights).
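As a sanity check, the sketch below compares the weighted, accumulated Kronecker sum with what nnet:::multinomHess computes. It is minimal standalone C++ using plain std::vector, with no Armadillo. The function names info_kronecker, info_amemiya and example_max_diff and the toy data are made up for illustration. Matrices are stored row-major in flat vectors, and the reference category is dropped from probs:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;  // row-major: (r, c) of an R x C matrix is a[r * C + c]

// Weighted Kronecker-sum form, accumulated over observations:
//   info = sum_i w_i * kron(diag(p_i) - p_i p_i^T, z_i z_i^T)
// probs: n x k (reference category dropped), Z: n x p, w: row totals (length n).
Vec info_kronecker(const Vec& probs, const Vec& Z, const Vec& w,
                   std::size_t n, std::size_t p, std::size_t k) {
    const std::size_t m = k * p;
    Vec info(m * m, 0.0);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < k; ++j)
            for (std::size_t l = 0; l < k; ++l) {
                const double pij = probs[i * k + j], pil = probs[i * k + l];
                const double a = (j == l ? pij : 0.0) - pij * pil;  // (j, l) of diag(p_i) - p_i p_i^T
                for (std::size_t r = 0; r < p; ++r)
                    for (std::size_t s = 0; s < p; ++s)
                        info[(j * p + r) * m + (l * p + s)] += w[i] * a * Z[i * p + r] * Z[i * p + s];
            }
    return info;
}

// Amemiya-style form, mirroring the double loop in nnet:::multinomHess:
//   info = sum_i w_i sum_{j=0..k} p_ij (x_ij - xbar_i)(x_ij - xbar_i)^T,
// with x_i0 = 0 (reference category), x_ij = e_j (x) z_i and xbar_i = p_i (x) z_i.
Vec info_amemiya(const Vec& probs, const Vec& Z, const Vec& w,
                 std::size_t n, std::size_t p, std::size_t k) {
    const std::size_t m = k * p;
    Vec info(m * m, 0.0), xbar(m), x(m);
    for (std::size_t i = 0; i < n; ++i) {
        double p0 = 1.0;  // reference-category probability = 1 - sum of the others
        for (std::size_t j = 0; j < k; ++j) {
            p0 -= probs[i * k + j];
            for (std::size_t r = 0; r < p; ++r) xbar[j * p + r] = probs[i * k + j] * Z[i * p + r];
        }
        for (std::size_t j = 0; j <= k; ++j) {  // j = 0 is the reference category
            const double pij = (j == 0) ? p0 : probs[i * k + (j - 1)];
            for (std::size_t c = 0; c < m; ++c) x[c] = -xbar[c];
            if (j > 0)
                for (std::size_t r = 0; r < p; ++r) x[(j - 1) * p + r] += Z[i * p + r];
            for (std::size_t a = 0; a < m; ++a)  // one rank-one (m x m) update per (i, j)
                for (std::size_t b = 0; b < m; ++b) info[a * m + b] += w[i] * pij * x[a] * x[b];
        }
    }
    return info;
}

// Self-check on made-up data (n = 3, p = 2, k = 2): largest absolute
// difference between the two forms, expected to be at rounding level.
double example_max_diff() {
    const Vec Z = {1.0, 0.5, 1.0, -1.2, 1.0, 2.0};
    const Vec probs = {0.2, 0.3, 0.1, 0.6, 0.25, 0.25};
    const Vec w = {3.0, 1.0, 2.0};
    const Vec A = info_kronecker(probs, Z, w, 3, 2, 2);
    const Vec B = info_amemiya(probs, Z, w, 3, 2, 2);
    assert(A.size() == B.size());
    double d = 0.0;
    for (std::size_t c = 0; c < A.size(); ++c) d = std::max(d, std::fabs(A[c] - B[c]));
    return d;
}
```

The two forms give the same matrix. Per observation, though, the Amemiya form does roughly (k + 1) · (kp)² multiply-adds, while the Kronecker form does about (kp)², so it needs roughly a factor k + 1 fewer operations.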
Rcpp function:
// RcppArmadillo utility function to calculate observed Fisher
// information matrix of multinomial fit, with
// probs = fitted probabilities (with 1st category/column dropped)
// Z = model matrix
// row_totals = row totals
// We do this using Kronecker products, as in
// https://ieeexplore.ieee.org/abstract/document/1424458
// B. Krishnapuram; L. Carin; M.A.T. Figueiredo; A.J. Hartemink
// Sparse multinomial logistic regression: fast algorithms and
// generalization bounds
// IEEE Transactions on Pattern Analysis and Machine
// Intelligence (Volume: 27, Issue: 6, June 2005)
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::mat calc_infmatrix_RcppArma(const arma::mat& probs, const arma::mat& Z,
                                  const arma::vec& row_totals) {
  const arma::uword n = Z.n_rows;
  const arma::uword k = probs.n_cols;
  const arma::uword ncoefs = k * Z.n_cols;
  arma::mat info = arma::zeros<arma::mat>(ncoefs, ncoefs);
  for (arma::uword i = 0; i < n; i++) {
    // k x k factor: diag(p_i) - p_i p_i^T
    arma::mat diag_probs = arma::diagmat(probs.row(i));
    arma::mat tcrossprod_probs = arma::trans(probs.row(i)) * probs.row(i);
    // p x p factor: row total times z_i z_i^T
    arma::mat tcrossprod_Z = (arma::trans(Z.row(i)) * Z.row(i)) * row_totals(i);
    info += arma::kron(diag_probs - tcrossprod_probs, tcrossprod_Z);
  }
  return info;
}
This is saved as "calc_infmatrix_arma.cpp" and compiled from R with:
library(Rcpp)
library(RcppArmadillo)
sourceCpp("calc_infmatrix_arma.cpp")
R wrapper function:
# Function to calculate Hessian / observed Fisher information
# matrix of nnet::multinom multinomial fit object
fastmultinomHess <- function(object, Z = model.matrix(object)) {
  probs <- object$fitted # predicted probabilities, avoid napredict from fitted.default
  coefs <- coef(object)
  if (is.vector(coefs)) { # i.e. there are only 2 response categories
    coefs <- t(as.matrix(coefs))
    probs <- cbind(1 - probs, probs)
  }
  coefdim <- dim(coefs)
  p <- coefdim[2L] # nr of parameters
  k <- coefdim[1L] # nr of outcome categories - 1
  ncoefs <- k * p # nr of coefficients
  n <- dim(Z)[1L] # nr of observations
  # Now compute the Hessian = the observed
  # (= expected, in this case)
  # Fisher information matrix
  info <- calc_infmatrix_RcppArma(probs = probs[, -1, drop = FALSE],
                                  Z = Z,
                                  row_totals = object$weights)
  Names <- dimnames(coefs)
  if (is.null(Names[[1L]])) {
    Names <- Names[[2L]]
  } else {
    Names <- as.vector(outer(Names[[2L]], Names[[1L]],
                             function(name2, name1) paste(name1, name2, sep = ":")))
  }
  dimnames(info) <- list(Names, Names)
  return(info)
}
For my larger model this now runs in about 100 s instead of more than 2 hours, i.e. almost 80 times faster:
download.file("https://www.dropbox.com/s/mpz08jj7fmubd68/bigmodel.RData?dl=1",
              "bigmodel.RData",
              method = "auto", mode = "wb")
load("bigmodel.RData")
object = fit_global_multi_last3m # large nnet::multinom fit
system.time(info <- fastmultinomHess(object, Z = model.matrix(object))) # 103s
system.time(info <- nnet:::multinomHess(object, Z = model.matrix(object))) # 8127s = 2.25h
A pure R version of the calc_infmatrix function (ca. 5x slower than the Rcpp function above) would be:
# Utility function to calculate observed Fisher information matrix
# of multinomial fit, with
# probs = fitted probabilities (with 1st category/column dropped)
# Z = model matrix
# row_totals = row totals
# Requires the fastmatrix package (for its kronecker.prod Kronecker product)
calc_infmatrix <- function(probs, Z, row_totals) {
  n <- nrow(Z)
  p <- ncol(Z)
  k <- ncol(probs)
  ncoefs <- k * p
  info <- matrix(0, ncoefs, ncoefs)
  for (i in seq_len(n)) {
    # nrow = k keeps a k x k matrix even when k == 1; otherwise diag() of a
    # single number is interpreted as the size of an identity matrix
    info <- info + fastmatrix::kronecker.prod(diag(probs[i, ], nrow = k) - tcrossprod(probs[i, ]),
                                              tcrossprod(Z[i, ]) * row_totals[i])
  }
  return(info)
}
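The information matrix is symmetric, and its (j, l) block of size p × p equals sum_i w_i (delta_jl p_ij - p_ij p_il) z_i z_i^T. A further, unbenchmarked idea is therefore to compute only the k(k + 1)/2 blocks with j ≤ l and mirror the rest. Below is a minimal standalone C++ sketch of that idea. It assumes flat row-major std::vector storage, and info_blockwise is a hypothetical name that is not part of nnet or RcppArmadillo:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;  // row-major: (r, c) of an R x C matrix is a[r * C + c]

// Hypothetical helper: symmetric, blockwise computation of the weighted
// multinomial information matrix. Block (j, l) is
//   I_jl = sum_i w_i (delta_jl p_ij - p_ij p_il) z_i z_i^T,
// and only blocks with j <= l are accumulated; the lower blocks are mirrored.
// probs: n x k (reference category dropped), Z: n x p, w: row totals (length n).
Vec info_blockwise(const Vec& probs, const Vec& Z, const Vec& w,
                   std::size_t n, std::size_t p, std::size_t k) {
    assert(probs.size() == n * k && Z.size() == n * p && w.size() == n);
    const std::size_t m = k * p;
    Vec info(m * m, 0.0);
    for (std::size_t j = 0; j < k; ++j)
        for (std::size_t l = j; l < k; ++l)  // upper blocks only
            for (std::size_t i = 0; i < n; ++i) {
                const double pij = probs[i * k + j], pil = probs[i * k + l];
                const double v = w[i] * ((j == l ? pij : 0.0) - pij * pil);
                for (std::size_t r = 0; r < p; ++r)
                    for (std::size_t s = 0; s < p; ++s)
                        info[(j * p + r) * m + (l * p + s)] += v * Z[i * p + r] * Z[i * p + s];
            }
    // Fill the strict lower triangle from the upper one (I_lj = I_jl^T).
    for (std::size_t a = 0; a < m; ++a)
        for (std::size_t b = 0; b < a; ++b)
            info[a * m + b] = info[b * m + a];
    return info;
}
```

This roughly halves the off-diagonal work compared with filling all k² blocks. Whether that matters in practice next to Armadillo's kron has not been measured.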