Tags: prediction, r-mice, multinomial, marginal-effects, r-marginaleffects

How to pool average predictions from a multinomial regression per category in R


I want to obtain pooled average predictions for each level of a categorical variable, using a multinomial regression fitted to multiply imputed data. My problem is that pool() seems to collapse all the levels into a single estimated average prediction instead of keeping one estimate per category. Here is a step-by-step example:

  1. Using the rent data from the catdata package, I create two categorical variables:
library("tidyverse")
library("catdata")
library("mice")
library("nnet")
library("marginaleffects")
data(rent)

rent<-rent%>%
  mutate(rent_cat=factor(case_when(
    rent<389.95~"Low price",
    rent>=389.95&rent<700.48~"Medium price",
    rent>=700.48~"High price"
  )))%>%
  mutate(size_cat=factor(case_when(
    size<53~"Small",
    size>=53&size<83~"Medium",
    size>=83~"Big"
  )))
  2. Run multiple imputation by chained equations:
micerent<-mice(rent,m=3)
  3. Obtain average predictions per level of a categorical variable:

Note: based on this post, it is better to write a function that extracts the predicted probabilities and apply it to each completed imputed dataset. I'm using avg_predictions(), following the suggestion in the marginaleffects package tutorial.

fit_reg <- function(dat) {
  mod <- multinom(rent_cat~ size_cat,data=dat)
  out <- avg_predictions(mod,type="probs",by="size_cat")
  return(out)
}

micecompleterent<- complete(micerent, "all")

model<-lapply(micecompleterent, fit_reg)
  4. Pool the average predictions:

Here, instead of producing one average prediction per level of size_cat, pool() seems to collapse the stratified averages into a single pooled average prediction.

summary(pool(model),conf.int=T)

[Image: pooled results]

I also tried using newdata=datagrid(size_cat=c("Big","Medium","Small")) in the predictions() function from the marginaleffects package, but the result is the same:

fit_reg_response <- function(dat) {
  # assign the fitted model so predictions() can use it
  mod <- multinom(rent_cat~ size_cat,data=dat)
  out <- predictions(mod,newdata=datagrid(size_cat=c("Big","Medium","Small")))
  return(out)
}

Solution

  • There are two problems:

    1. Internally, to combine results from different model objects, mice::pool() calls the tidy() function on each object. It expects the output of tidy() to be a data.frame with a term column, but tidy() does not return such a column in this case, because there is no robust and general way to say what a “term” is in the context of average predictions.
    2. For good (but boring and irrelevant) reasons, tidy() returns a one-row average estimate when applied to the output of avg_predictions().
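
To see why the term column matters, here is a minimal base R sketch of pooling row by row with Rubin's rules, the combining rules that mice::pool() applies. This is an illustration under stated assumptions, not mice's actual code: pool_by_term() is a hypothetical helper, the numbers are made up, and the degrees-of-freedom calculation that mice also performs is left out.

```r
# Toy per-imputation results (made-up numbers, not from the rent data).
imp1 <- data.frame(term = c("a", "b"), estimate = c(0.60, 0.20), std.error = c(0.02, 0.01))
imp2 <- data.frame(term = c("a", "b"), estimate = c(0.62, 0.18), std.error = c(0.02, 0.01))
imp3 <- data.frame(term = c("a", "b"), estimate = c(0.61, 0.19), std.error = c(0.02, 0.01))

# Hypothetical helper: pool estimates term by term with Rubin's rules.
pool_by_term <- function(results) {
  m <- length(results)                  # number of imputations
  stacked <- do.call(rbind, results)    # one long table of all rows
  pooled <- lapply(split(stacked, stacked$term), function(d) {
    qbar  <- mean(d$estimate)           # pooled point estimate
    ubar  <- mean(d$std.error^2)        # within-imputation variance
    b     <- var(d$estimate)            # between-imputation variance
    total <- ubar + (1 + 1 / m) * b     # total variance
    data.frame(term = d$term[1], estimate = qbar, std.error = sqrt(total))
  })
  out <- do.call(rbind, pooled)
  rownames(out) <- NULL
  out
}

pooled <- pool_by_term(list(imp1, imp2, imp3))
print(pooled)
```

Rows are matched across imputations by their term value, so without distinct labels there is nothing to keep the categories apart.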

    There is not a super straightforward way to solve this, but here's a workaround:

    1. Assign a custom class name to the output of your fit_reg() function.
    2. Define a tidy() S3 method which creates the term column you want, based on the row identifiers you care about.
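
The two steps above rely on ordinary S3 dispatch. The following package-free sketch shows the mechanism only: summarise_rows() is a hypothetical generic standing in for tidy(), its default method merely imitates the one-row collapse described earlier, and the data are made up.

```r
# Hypothetical generic standing in for tidy().
summarise_rows <- function(x, ...) UseMethod("summarise_rows")

# Default method imitates the unwanted behaviour: collapse to one row.
summarise_rows.default <- function(x, ...) {
  data.frame(estimate = mean(x$estimate))
}

# Method for the custom class: keep every row and add a `term` identifier.
summarise_rows.custom_class <- function(x, ...) {
  data.frame(term = paste(x$group, x$size_cat), estimate = x$estimate)
}

# Made-up results table.
res <- data.frame(
  group    = c("Low", "High"),
  size_cat = c("Small", "Small"),
  estimate = c(0.6, 0.4)
)

collapsed <- summarise_rows(res)              # default method: one row

class(res) <- c("custom_class", class(res))   # step 1: prepend the class
kept <- summarise_rows(res)                   # step 2: custom method is used
print(kept)
```

Because the custom class comes first in the class vector, its method is found before any method for the original classes.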

    Old code (the setup from the question, with a seed added to mice() for reproducibility)

    library("tidyverse")
    library("catdata")
    library("mice")
    library("nnet")
    library("marginaleffects")
    data(rent)
    
    rent<-rent%>%
      mutate(rent_cat=factor(case_when(
        rent<389.95~"Low price",
        rent>=389.95&rent<700.48~"Medium price",
        rent>=700.48~"High price"
      )))%>%
      mutate(size_cat=factor(case_when(
        size<53~"Small",
        size>=53&size<83~"Medium",
        size>=83~"Big"
      )))
    
    micerent <- mice(rent, m = 3, seed = 1024)
    

    New code

    fit_reg <- function(dat) {
        mod <- multinom(rent_cat ~ size_cat, data = dat, trace = FALSE)
        out <- avg_predictions(mod, type = "probs", by = "size_cat")
        # the next line is key
        class(out) <- c("custom_class", class(out))
        return(out)
    }
    
    micecompleterent <- complete(micerent, "all")
    
    model <- lapply(micecompleterent, fit_reg)
    tidy.custom_class <- function(x, ...) {
        # create a row identifier
        transform(x, term = paste0(group, size_cat))
    }
    
    summary(pool(model), conf.int = TRUE)
    #                 term     estimate    std.error   statistic       df
    # 1      High priceBig 6.131852e-01 0.0211548249 28.98559586 2041.778
    # 2   High priceMedium 1.867514e-01 0.0122504868 15.24440754 2041.778
    # 3    High priceSmall 2.399958e-06 0.0000685316  0.03501973 2041.778
    # 4       Low priceBig 5.283582e-02 0.0097171507  5.43737813 2041.778
    # 5    Low priceMedium 1.650216e-01 0.0116685744 14.14239902 2041.778
    # 6     Low priceSmall 6.223208e-01 0.0214465891 29.01723885 2041.778
    # 7    Medium priceBig 3.339790e-01 0.0204863976 16.30247432 2041.778
    # 8 Medium priceMedium 6.482270e-01 0.0150108250 43.18396572 2041.778
    # 9  Medium priceSmall 3.776768e-01 0.0214465633 17.61013157 2041.778
    #         p.value         2.5 %       97.5 %
    # 1 5.135684e-155  0.5716979181 0.6546724948
    # 2  8.872206e-50  0.1627266590 0.2107761683
    # 3  9.720674e-01 -0.0001319992 0.0001367991
    # 4  6.054466e-08  0.0337792608 0.0718923850
    # 5  2.168686e-43  0.1421380844 0.1879051861
    # 6 2.679125e-155  0.5802613238 0.6643802735
    # 7  2.898326e-56  0.2938025529 0.3741553884
    # 8 5.330330e-290  0.6187888240 0.6776650781
    # 9  9.846943e-65  0.3356173772 0.4197362256
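
The term labels in that output, such as High priceBig, run the two identifiers together. If you want the original columns back, one option is to build term with an explicit separator, for example paste(group, size_cat, sep = "|") inside the tidy method, and split it afterwards. The sketch below assumes that separator and uses a made-up stand-in for the pooled summary.

```r
# Toy stand-in for a pooled summary whose `term` was built with
# paste(group, size_cat, sep = "|") (assumed separator, made-up numbers).
pooled <- data.frame(
  term     = c("High price|Big", "Low price|Small"),
  estimate = c(0.61, 0.62)
)

# Split `term` on the literal separator and bind the pieces into a matrix.
parts <- do.call(rbind, strsplit(as.character(pooled$term), "|", fixed = TRUE))

pooled$group    <- parts[, 1]   # outcome category
pooled$size_cat <- parts[, 2]   # stratification level
print(pooled)
```

Any separator works as long as it cannot occur inside the factor levels themselves.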