I want to obtain pooled average predictions per level of a categorical variable using a multinomial regression on multiply imputed data. My problem is that the pool() function seems to collapse all the levels into one single estimated average prediction instead of keeping the stratification by category. Here is a step-by-step example:
Using the rent data from library(catdata), I create 2 categorical variables:

library("tidyverse")
library("catdata")
library("mice")
library("nnet")
library("marginaleffects")
data(rent)
rent<-rent%>%
mutate(rent_cat=factor(case_when(
rent<389.95~"Low price",
rent>=389.95&rent<700.48~"Medium price",
rent>=700.48~"High price"
)))%>%
mutate(size_cat=factor(case_when(
size<53~"Small",
size>=53&size<83~"Medium",
size>=83~"Big"
)))
micerent<-mice(rent,m=3)
Note that, based on this post, it is better to write a function that extracts the predicted probabilities and to apply it to each completed imputed dataset. I'm using avg_predictions(), following a suggestion from the marginaleffects package tutorial.
fit_reg <- function(dat) {
mod <- multinom(rent_cat~ size_cat,data=dat)
out <- avg_predictions(mod,type="probs",by="size_cat")
return(out)
}
micecompleterent<- complete(micerent, "all")
model<-lapply(micecompleterent, fit_reg)
Here, instead of producing an average prediction for each level of size_cat, pool() seems to collapse all the stratified averages into one single pooled average prediction:
summary(pool(model),conf.int=T)
I also tried using newdata = datagrid(size_cat = c("Big", "Medium", "Small")) in the predictions() function from the marginaleffects package, but the result is the same:
fit_reg_response <- function(dat) {
mod <- multinom(rent_cat ~ size_cat, data = dat)
out <- predictions(mod,newdata=datagrid(size_cat=c("Big","Medium","Small")))
return(out)
}
There are two problems:

1. mice::pool() calls the tidy() function on all objects. It expects the output of tidy() to be a data.frame with a term column, but tidy() does not return such a column in this case, because there is no robust and general way to say what a “term” is in the context of average predictions.
2. tidy() returns a one-row average estimate when applied to the output of avg_predictions().
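To see why the term column matters: pool() combines estimates across imputations separately for each term, using Rubin's rules. The base-R sketch below mimics that grouping on a small hypothetical table (columns imp, term, estimate, std.error, with made-up numbers, not real mice output). It is not mice's internal code; in particular it skips the degrees-of-freedom adjustment that pool() reports.

```r
# Hypothetical per-imputation results: one row per (imputation, term).
# The numbers are made up for illustration only.
est <- data.frame(
  imp       = rep(1:3, each = 2),
  term      = rep(c("Low priceSmall", "High priceBig"), times = 3),
  estimate  = c(0.62, 0.61, 0.63, 0.62, 0.61, 0.60),
  std.error = c(0.020, 0.021, 0.022, 0.020, 0.021, 0.022),
  stringsAsFactors = FALSE
)

# Rubin's rules for a single scalar parameter
rubin <- function(q, se) {
  m    <- length(q)
  qbar <- mean(q)                  # pooled point estimate
  ubar <- mean(se^2)               # average within-imputation variance
  b    <- var(q)                   # between-imputation variance
  t    <- ubar + (1 + 1 / m) * b   # total variance
  c(estimate = qbar, std.error = sqrt(t))
}

# Pool separately within each term: rows that share a term are combined,
# rows with different terms stay separate.
pooled <- do.call(rbind, lapply(split(est, est$term), function(d) {
  r <- rubin(d$estimate, d$std.error)
  data.frame(term = d$term[1],
             estimate = r[["estimate"]],
             std.error = r[["std.error"]],
             stringsAsFactors = FALSE)
}))
pooled
```

Without a term that distinguishes the rows, there is nothing to group on, so per-category pooling is impossible.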
There is no super straightforward way to solve this, but here's a workaround:

1. Assign a custom class to the object returned by your fit_reg() function.
2. Define a tidy() S3 method for that class which creates the term column you want, based on the row identifiers you care about.

library("tidyverse")
library("catdata")
library("mice")
library("nnet")
library("marginaleffects")
data(rent)
rent<-rent%>%
mutate(rent_cat=factor(case_when(
rent<389.95~"Low price",
rent>=389.95&rent<700.48~"Medium price",
rent>=700.48~"High price"
)))%>%
mutate(size_cat=factor(case_when(
size<53~"Small",
size>=53&size<83~"Medium",
size>=83~"Big"
)))
micerent <- mice(rent, m = 3, seed = 1024)
fit_reg <- function(dat) {
mod <- multinom(rent_cat ~ size_cat, data = dat, trace = FALSE)
out <- avg_predictions(mod, type = "probs", by = "size_cat")
# the next line is key
class(out) <- c("custom_class", class(out))
return(out)
}
micecompleterent <- complete(micerent, "all")
model <- lapply(micecompleterent, fit_reg)
tidy.custom_class <- function(x, ...) {
# create a row identifier
transform(x, term = paste0(group, size_cat))
}
summary(pool(model), conf.int = TRUE)
# term estimate std.error statistic df
# 1 High priceBig 6.131852e-01 0.0211548249 28.98559586 2041.778
# 2 High priceMedium 1.867514e-01 0.0122504868 15.24440754 2041.778
# 3 High priceSmall 2.399958e-06 0.0000685316 0.03501973 2041.778
# 4 Low priceBig 5.283582e-02 0.0097171507 5.43737813 2041.778
# 5 Low priceMedium 1.650216e-01 0.0116685744 14.14239902 2041.778
# 6 Low priceSmall 6.223208e-01 0.0214465891 29.01723885 2041.778
# 7 Medium priceBig 3.339790e-01 0.0204863976 16.30247432 2041.778
# 8 Medium priceMedium 6.482270e-01 0.0150108250 43.18396572 2041.778
# 9 Medium priceSmall 3.776768e-01 0.0214465633 17.61013157 2041.778
# p.value 2.5 % 97.5 %
# 1 5.135684e-155 0.5716979181 0.6546724948
# 2 8.872206e-50 0.1627266590 0.2107761683
# 3 9.720674e-01 -0.0001319992 0.0001367991
# 4 6.054466e-08 0.0337792608 0.0718923850
# 5 2.168686e-43 0.1421380844 0.1879051861
# 6 2.679125e-155 0.5802613238 0.6643802735
# 7 2.898326e-56 0.2938025529 0.3741553884
# 8 5.330330e-290 0.6187888240 0.6776650781
# 9 9.846943e-65 0.3356173772 0.4197362256
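The term labels above fuse the outcome level and the size category (for example High priceBig). A possible variant of the same tidy.custom_class idea uses an explicit separator, so the pooled table can be split back into two columns afterwards. The sketch below runs on a hypothetical stand-in data frame with only the columns the tidy method uses (group, size_cat, estimate, std.error); add_term() is a hypothetical helper, not part of mice or marginaleffects, and the separator " | " is an arbitrary choice.

```r
# Stand-in for avg_predictions() output: hypothetical values.
preds <- data.frame(
  group     = rep(c("High price", "Low price"), each = 2),
  size_cat  = rep(c("Big", "Small"), times = 2),
  estimate  = c(0.61, 0.00, 0.05, 0.62),
  std.error = c(0.021, 0.0001, 0.010, 0.021),
  stringsAsFactors = FALSE
)

# Same idea as tidy.custom_class, but with an explicit separator
# so the two identifiers can be recovered unambiguously later.
add_term <- function(x, sep = " | ") {
  transform(x, term = paste(group, size_cat, sep = sep))
}

tidied <- add_term(preds)
tidied$term

# After pooling, split the term back into its two identifiers
parts <- do.call(rbind, strsplit(as.character(tidied$term), " | ", fixed = TRUE))
colnames(parts) <- c("group", "size_cat")
parts
```

In the real workflow the body of add_term() would go inside tidy.custom_class, and the splitting step would be applied to the term column of summary(pool(model)).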