I would like R's `nnet::multinom` function to be supported by the new `marginaleffects` package. However, `marginaleffects::predictions()` relies on the `predict()` methods supplied by the modeling packages to compute predicted values on both the response and the link scale. In the case of `nnet::multinom`, the `predict()` method supplied by `nnet` does not support predictions on the link scale. It only supports `type="probs"` or `type="class"` (see https://github.com/vincentarelbundock/marginaleffects/issues/404). So I would like to redefine the `predict.multinom` method of `nnet::multinom` so that it also supports `type="link"`. I want to do this in the original namespace of that package, so that the `marginaleffects` package would also see it as having been redefined. Is there any way to accomplish this?
For reference, the `predict.multinom` method (https://github.com/cran/nnet/blob/master/R/multinom.R) currently looks like this:
# Behaviour-preserving restatement of nnet's predict.multinom().
# object:  a fitted "multinom" model.
# newdata: optional data frame; rows dropped by na.omit give NA predictions.
#          When missing, the fitted probabilities are used.
# type:    "class" returns a factor of predicted levels; "probs" returns the
#          probability matrix (dimensions dropped where possible).
predict.multinom <- function(object, newdata, type=c("class","probs"), ...)
{
  if (!inherits(object, "multinom")) {
    stop("not a \"multinom\" fit")
  }
  type <- match.arg(type)

  if (missing(newdata)) {
    out <- fitted(object)
  } else {
    newdata <- as.data.frame(newdata)
    row_ids <- row.names(newdata)
    rhs_terms <- delete.response(object$terms)
    mf <- model.frame(rhs_terms, newdata, na.action = na.omit,
                      xlev = object$xlevels)
    data_classes <- attr(rhs_terms, "dataClasses")
    if (!is.null(data_classes)) {
      .checkMFClasses(data_classes, mf)
    }
    # positions (within newdata) of the rows that survived na.omit
    kept_rows <- match(row.names(mf), row_ids)
    design <- model.matrix(rhs_terms, mf, contrasts = object$contrasts)
    net_out <- predict.nnet(object, design)
    out <- matrix(NA, nrow(newdata), ncol(net_out),
                  dimnames = list(row_ids, colnames(net_out)))
    out[kept_rows, ] <- net_out
  }

  if (type == "class") {
    n_lev <- length(object$lev)
    if (n_lev > 2L) {
      out <- factor(max.col(out), levels = seq_along(object$lev),
                    labels = object$lev)
    } else if (n_lev == 2L) {
      # binary fits carry a single column: P(second level)
      out <- factor(1 + (out > 0.5), levels = 1L:2L, labels = object$lev)
    } else if (n_lev == 0L) {
      out <- factor(max.col(out), levels = seq_along(object$lab),
                    labels = object$lab)
    }
  }
  drop(out)
}
with `predict.nnet` (https://github.com/cran/nnet/blob/master/R/nnet.R) being given by
# Predict method for "nnet" fits (as quoted from nnet's R/nnet.R).
# object:  a fitted "nnet" model.
# newdata: optional new inputs. For formula fits (class "nnet.formula") it is
#          coerced to a data frame; otherwise a matrix, data frame or a single
#          vector (treated as one row). When missing, fitted values are used.
# type:    "raw" returns the matrix of network outputs; "class" returns, per
#          row, the element of object$lev with the largest output (or, with a
#          single output column, the second level when the output is > 0.5).
# For formula fits, rows of newdata removed by na.omit come back as NA.
predict.nnet <- function(object, newdata, type=c("raw","class"), ...)
{
if(!inherits(object, "nnet")) stop("object not of class \"nnet\"")
type <- match.arg(type)
if(missing(newdata)) z <- fitted(object)
else {
if(inherits(object, "nnet.formula")) { #
## formula fit
newdata <- as.data.frame(newdata)
rn <- row.names(newdata)
## work hard to predict NA for rows with missing data
Terms <- delete.response(object$terms)
m <- model.frame(Terms, newdata, na.action = na.omit,
xlev = object$xlevels)
if (!is.null(cl <- attr(Terms, "dataClasses")))
.checkMFClasses(cl, m)
# positions (within newdata) of the rows that survived na.omit
keep <- match(row.names(m), rn)
x <- model.matrix(Terms, m, contrasts = object$contrasts)
xint <- match("(Intercept)", colnames(x), nomatch=0L)
if(xint > 0L) x <- x[, -xint, drop=FALSE] # Bias term is used for intercepts
} else {
## matrix ... fit
if(is.null(dim(newdata)))
dim(newdata) <- c(1L, length(newdata)) # a row vector
x <- as.matrix(newdata) # to cope with dataframes
if(any(is.na(x))) stop("missing values in 'x'")
keep <- 1L:nrow(x)
rn <- rownames(x)
}
ntr <- nrow(x)
# number of output units
nout <- object$n[3L]
# Hand the network layout to the compiled routine VR_set_net (weights are
# passed as zeros here; the real weights go to VR_nntest below).
# NOTE(review): presumably this sets global C-side state, which is why it is
# paired with VR_unset_net at the end - confirm in nnet's C sources.
.C(VR_set_net,
as.integer(object$n), as.integer(object$nconn),
as.integer(object$conn), rep(0.0, length(object$wts)),
as.integer(object$nsunits), as.integer(0L),
as.integer(object$softmax), as.integer(object$censored))
# NA-filled result, one row per row of newdata; only `keep` rows get outputs
z <- matrix(NA, nrow(newdata), nout,
dimnames = list(rn, dimnames(object$fitted.values)[[2L]]))
# VR_nntest fills `tclass` with ntr * nout outputs, reshaped column-wise
z[keep, ] <- matrix(.C(VR_nntest,
as.integer(ntr),
as.double(x),
tclass = double(ntr*nout),
as.double(object$wts)
)$tclass, ntr, nout)
.C(VR_unset_net)
}
switch(type, raw = z,
class = {
if(is.null(object$lev)) stop("inappropriate fit for class")
if(ncol(z) > 1L) object$lev[max.col(z)]
else object$lev[1L + (z > 0.5)]
})
}
I was hoping I could perhaps overwrite the `predict.multinom` function with the `predict.mblogit` function (https://github.com/melff/mclogit/blob/master/pkg/R/mblogit.R), or something close to it. Some minor edits would probably be needed, since `mblogit` and `nnet` objects are structured slightly differently:
# Predict method for "mblogit" fits (as quoted from mclogit's R/mblogit.R).
# object:  a fitted mblogit model.
# newdata: optional data frame; when missing, the stored model frame is used.
# type:    "link" returns the linear predictors with the first column removed;
#          "response" returns softmax probabilities (rows sum to 1).
# se.fit:  when TRUE, returns list(fit = ..., se.fit = ...) instead.
# If an na.action is recorded, results are passed through napredict().
predict.mblogit <- function(object, newdata=NULL,type=c("link","response"),se.fit=FALSE,...){
type <- match.arg(type)
mt <- terms(object)
rhs <- delete.response(mt)
# NOTE(review): newdata defaults to NULL but the test is missing(), so an
# explicit newdata=NULL takes the model.frame() branch - verify intent.
if(missing(newdata)){
m <- object$model
na.act <- object$na.action
}
else{
m <- model.frame(rhs,data=newdata,na.action=na.exclude)
na.act <- attr(m,"na.action")
}
# NOTE(review): `m` already carries a terms attribute, so model.matrix()
# presumably ignores xlev here; passing xlev to model.frame() may be
# needed for newdata with fewer factor levels - confirm.
X <- model.matrix(rhs,m,
contrasts.arg=object$contrasts,
xlev=object$xlevels
)
rn <- rownames(X)
# presumably D maps response categories to the model's equations - confirm
D <- object$D
# Kronecker product: each row of X expands to nrow(D) consecutive rows
XD <- X%x%D
# reshape a long (observation-major) vector into an
# observations x categories matrix labelled by rownames(D)
rspmat <- function(x){
y <- t(matrix(x,nrow=nrow(D)))
colnames(y) <- rownames(D)
y
}
eta <- c(XD %*% coef(object))
eta <- rspmat(eta)
rownames(eta) <- rn
if(se.fit){
V <- vcov(object)
stopifnot(ncol(XD)==ncol(V))
}
if(type=="response") {
# softmax of the linear predictors
exp.eta <- exp(eta)
sum.exp.eta <- rowSums(exp.eta)
p <- exp.eta/sum.exp.eta
if(se.fit){
p.long <- as.vector(t(p))
# observation index of every row of XD
s <- rep(1:nrow(X),each=nrow(D))
# derivative of each probability w.r.t. the coefficients:
# p_j * (x_j - sum_k p_k x_k) within each observation
wX <- p.long*(XD - rowsum(p.long*XD,s)[s,,drop=FALSE])
# sqrt(diag(wX %*% V %*% t(wX))) without forming the full product
se.p.long <- sqrt(rowSums(wX * (wX %*% V)))
se.p <- rspmat(se.p.long)
rownames(se.p) <- rownames(p)
if(is.null(na.act))
list(fit=p,se.fit=se.p)
else
list(fit=napredict(na.act,p),
se.fit=napredict(na.act,se.p))
}
else {
if(is.null(na.act)) p
else napredict(na.act,p)
}
}
else if(se.fit) {
se.eta <- sqrt(rowSums(XD * (XD %*% V)))
se.eta <- rspmat(se.eta)
# drop the first category column (presumably the baseline) - confirm
eta <- eta[,-1,drop=FALSE]
se.eta <- se.eta[,-1,drop=FALSE]
if(is.null(na.act))
list(fit=eta,se.fit=se.eta)
else
list(fit=napredict(na.act,eta),
se.fit=napredict(na.act,se.eta))
}
else {
eta <- eta[,-1,drop=FALSE]
if(is.null(na.act)) eta
else napredict(na.act,eta)
}
}
Reproducible example of what I would like to achieve:
# data = SARS-CoV2 coronavirus variants (variant) through time
# (collection_date_num) in India, count = actual count (nr of sequenced genomes)
dat <- read.csv("https://www.dropbox.com/s/u27cn44p5srievq/dat.csv?dl=1")
dat$collection_date <- as.Date(dat$collection_date)
# numeric version of date; to convert back to a date use
# as.Date(dat$collection_date_num, origin = "1970-01-01")
dat$collection_date_num <- as.numeric(dat$collection_date)
dat$variant <- factor(dat$variant)

# 1. nnet::multinom multinomial fit ####
library(nnet)
library(splines)
set.seed(1)
fit_nnet <- nnet::multinom(variant ~ ns(collection_date_num, df = 2),
                           weights = count, data = dat)
summary(fit_nnet)

# 2. predicted probabilities & 95% CLs at maximum date calculated using emmeans: works, but slow for large models ####
library(emmeans)
multinom_emmeans <- emmeans(fit_nnet, ~ variant,
                            mode = "prob",
                            at = list(collection_date_num =
                                        max(dat$collection_date_num)))
multinom_emmeans
# variant prob SE df lower.CL upper.CL
# Alpha 0.00e+00 0.00e+00 33 0.00e+00 0.00e+00
# Beta 0.00e+00 0.00e+00 33 0.00e+00 0.00e+00
# Delta 7.73e-06 1.17e-06 33 5.34e-06 1.01e-05
# Omicron (BA.1) 1.82e-04 6.42e-05 33 5.14e-05 3.13e-04
# Omicron (BA.2) 1.76e-01 7.45e-03 33 1.61e-01 1.91e-01
# Omicron (BA.2.74) 9.03e-02 7.98e-03 33 7.41e-02 1.07e-01
# Omicron (BA.2.75) 1.68e-01 1.90e-02 33 1.30e-01 2.07e-01
# Omicron (BA.2.76) 2.89e-01 1.35e-02 33 2.62e-01 3.16e-01
# Omicron (BA.3) 1.34e-02 2.10e-03 33 9.10e-03 1.76e-02
# Omicron (BA.4) 1.67e-02 2.47e-03 33 1.17e-02 2.17e-02
# Omicron (BA.5) 2.03e-01 1.08e-02 33 1.81e-01 2.25e-01
# Other 4.23e-02 3.15e-03 33 3.59e-02 4.87e-02
#
# Confidence level used: 0.95

# 3. predicted probabilities & 95% CLs at maximum date calculated using marginaleffects: does not work because of lack of a predict.multinom method supporting type="link" ####
library(marginaleffects)
multinom_preds_marginaleffects <- predictions(
  fit_nnet,
  newdata = datagrid(collection_date_num = max(dat$collection_date_num)),
  type = "link", # not supported by nnet's predict.multinom
  transform_post = insight::link_inverse(fit_nnet)
)
# Error: The `type` argument for models of class `multinom` must be an element of: probs
# PS: desired output should match emmeans output above
The way to redefine a method in a package is to use `assignInNamespace()`. However, assume this is intended to be part of another package that will eventually be made public. In that case it's a bit rude, since you're trampling over someone else's code. In particular, if you intend to put it on CRAN, you may have trouble convincing the CRAN reviewers that it's OK.
A better solution would be to create a wrapper method that calls the original method. For this you'll also need to create a wrapper `multinom` function, so that the correct package namespace is found. A sketch implementation is shown below.
# Wrapper around nnet::multinom() meant to live in your own package, so that
# the predict() method defined alongside it is the one users reach.
#
# All arguments are those of nnet::multinom(). Rather than forwarding `...`
# (which makes this wrapper's frame the parent frame in which
# nnet::multinom() evaluates its model-frame call, so `data`/`weights`
# objects local to the caller are not found), the call is rebuilt and
# evaluated in the caller's frame, exactly as if nnet::multinom() had been
# called directly. The returned fit therefore also records a clean `call`.
multinom <- function(...) {
  cl <- match.call()
  cl[[1L]] <- quote(nnet::multinom)
  eval.parent(cl)
}
# Inverse of the softmax, i.e. the (generalized logit) link for the
# multinomial model: maps class probabilities to centred log-probabilities,
# the log-odds shifted so that they sum to zero within each observation.
# emmeans refers to predictions on this scale as type = "latent".
#
# mu: a matrix of class probabilities (one row per observation, one column
#     per class), or a single observation given as a (named) vector - which
#     is what nnet's predict.multinom() returns for a one-row newdata,
#     because it drop()s its result.
# Returns an object of the same shape as `mu`. A zero probability gives
# log(0) = -Inf, so the affected observation comes back non-finite.
inverse_softMax <- function(mu) {
  log_mu <- log(mu)
  if (is.null(dim(log_mu))) {
    # single observation: rowMeans() needs at least two dimensions
    return(log_mu - mean(log_mu))
  }
  sweep(log_mu, 1, STATS = rowMeans(log_mu), FUN = "-")
}
# predict() method for multinom fits that adds link-scale predictions.
#
# object:  a fitted nnet::multinom() model.
# newdata: optional data frame; when missing, the missingness is passed on to
#          nnet's method, which then uses the fitted probabilities.
# type:    "probs" or its alias "response" returns exactly what nnet's
#          predict.multinom(type = "probs") returns; "latent" or its alias
#          "link" returns centred log-probabilities (see inverse_softMax()).
# ...:     accepted for consistency with the predict() generic; unused.
#
# NOTE: nnet::: reaches into nnet's namespace; R CMD check will flag this.
predict.multinom <- function(object, newdata,
                             type = c("probs", "response", "latent", "link"),
                             ...) {
  type <- match.arg(type)
  mu <- nnet:::predict.multinom(object, newdata, type = "probs")
  if (type %in% c("probs", "response")) {
    return(mu)
  }
  # nnet's method drop()s its result, so restore an
  # observations x classes matrix before taking the link.
  if (length(object$lev) == 2L) {
    # binary fits only return P(second level) for each observation
    mu <- cbind(1 - mu, mu)
    colnames(mu) <- object$lev
  } else if (is.null(dim(mu))) {
    # a single row of newdata was dropped to a named vector
    mu <- matrix(mu, nrow = 1L, dimnames = list(NULL, names(mu)))
  }
  drop(inverse_softMax(mu))
}