Tags: r, ggplot2, na, ggproto

NA handling in a new **ggplot2** `Stat` (ggproto object)


I'm trying to create a new Stat in a ggplot2 extension. I don't understand where NA values are removed, or how to write my new ggproto object so that it supports other NA-handling options. Where and how are NAs removed when a Stat is computed, and how can this be overridden? I am staring at the remove_missing() call on line 93 of https://github.com/tidyverse/ggplot2/blob/main/R/stat-.r and struggling to see how to short-circuit it.

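I'd like my new stat to handle NAs differently; e.g. here I'm trying to replace NA with 999. With stat_summary(), however, the two plots below come out identical, because the rows containing NA are removed before my function is ever called.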

library(ggplot2)
library(patchwork)

# Deliberately odd summary function for a custom Stat: every NA in `x`
# counts as 999 before the average is taken.
mymean <- function(x) {
  mean(replace(x, is.na(x), 999))
}

# Toy data: NAs are planted in both columns, including one row that is NA in both.
mydata <- data.frame(
  x_mydata = c(1, NA, 1, 2, 2, 2, 3, 3, NA),
  y_mydata = c(1, 2, NA, 4, 5, 6, 7, NA, NA)
)

# mymean() counts each NA as 999, so it does not agree with mean(na.rm = TRUE);
# all.equal() prints a description of the difference rather than TRUE.
all.equal(
  mymean(mydata$y_mydata),
  mean(mydata$y_mydata, na.rm = TRUE)
)

# The same bar chart twice; only the summary function differs.
p1 <- ggplot(data = mydata, aes(x = x_mydata, y = y_mydata)) +
  stat_summary(geom = "bar", fun = mean)
p2 <- ggplot(data = mydata, aes(x = x_mydata, y = y_mydata)) +
  stat_summary(geom = "bar", fun = mymean)

# Show both side by side with patchwork. They look the same because the
# NA rows are removed before `fun` is applied.
p1 + p2

Solution

  • Simple version

    You can define your own Stat (inherited from StatSummary in the example below, since that is what the question uses, though you can inherit from another Stat if it suits your needs better) with a modified compute_layer() method that calls your own NA-handling function instead of ggplot2's remove_missing().

    Replacement function for remove_missing:

    # Stand-in for ggplot2's remove_missing(): keep every row, but overwrite
    # each NA in `x` (a vector or a data frame) with the sentinel 999.
    replace_missing <- function(x) {
      replace(x, is.na(x), 999)
    }
    

    Define your own Stat based on StatSummary (note: the non-exported ggplot2 helpers used by the original compute_layer() are called with the `ggplot2:::` prefix so that the new version can still reach them; because they are internal, they may change between ggplot2 releases):

    # StatNew: a StatSummary whose compute_layer() replaces NAs instead of
    # removing them. The method follows the structure of ggplot2's own
    # compute_layer(), with replace_missing() where remove_missing() was.
    # NOTE(review): ggplot2:::check_required_aesthetics(), ggplot2:::snake_class()
    # and ggplot2:::dapply() are internal and may change between ggplot2 releases.
    StatNew <- ggproto("StatNew", StatSummary,
                       compute_layer = function(self, data, params, layout) {
                         stat_name <- ggplot2:::snake_class(self)

                         # Error early if a required aesthetic is neither mapped
                         # nor supplied as a parameter.
                         ggplot2:::check_required_aesthetics(
                           self$required_aes,
                           c(names(data), names(params)),
                           stat_name
                         )

                         # Resolve "a|b" alternatives in required_aes to the
                         # aesthetics actually present in the data.
                         required_aes <- intersect(
                           names(data),
                           unlist(strsplit(self$required_aes, "|", fixed = TRUE))
                         )

                         # Instead of remove_missing(): replace NAs, but only in the
                         # required / non-missing aesthetic columns (the ones ggplot2
                         # checks), leaving PANEL, group and other aesthetics untouched.
                         na_cols <- intersect(c(required_aes, self$non_missing_aes), names(data))
                         if (length(na_cols) > 0) {
                           data[na_cols] <- replace_missing(data[na_cols])
                         }

                         # Keep only the parameters compute_panel() accepts.
                         params <- params[intersect(names(params), self$parameters())]

                         ggplot2:::dapply(data, "PANEL", function(data) {
                           scales <- layout$get_scales(data$PANEL[1])
                           # try_fetch()/inject() are exported by rlang, so `::` is enough.
                           rlang::try_fetch(
                             rlang::inject(self$compute_panel(data = data, scales = scales, !!!params)),
                             error = function(cnd) {
                               cli::cli_warn("Computation failed in {.fn {stat_name}}", parent = cnd)
                               # ggplot2's internal data_frame0() is not in scope here;
                               # an empty data frame likewise contributes no rows.
                               data.frame()
                             }
                           )
                         })
                       })
    

    Define your own stat_*() function that builds a layer using the new Stat:

    # Layer constructor for StatNew. The fun.* arguments, na.rm, orientation
    # and anything in `...` are forwarded to the layer as params; the geom
    # defaults to "pointrange" and the position to "identity".
    stat_new <- function(mapping = NULL, data = NULL, geom = "pointrange", position = "identity",
                         ..., fun.data = NULL, fun = NULL, fun.max = NULL, fun.min = NULL,
                         fun.args = list(), na.rm = FALSE, orientation = NA, show.legend = NA,
                         inherit.aes = TRUE) {
      # list2() is exported by rlang, so `::` suffices; `:::` is only needed
      # for a package's internal functions.
      layer(data = data, mapping = mapping, stat = StatNew,
            geom = geom, position = position, show.legend = show.legend,
            inherit.aes = inherit.aes,
            params = rlang::list2(fun.data = fun.data, fun = fun, fun.max = fun.max, fun.min = fun.min,
                                  fun.args = fun.args, na.rm = na.rm, orientation = orientation,
                                  ...))
    }
    

    Usage:

    # Like p1 in the question, but built on StatNew, so NAs become 999
    # instead of being dropped before `fun` is applied.
    p3 <- ggplot(data = mydata, aes(x = x_mydata, y = y_mydata)) +
      stat_new(geom = "bar", fun = mean)

    p3
    

    *(Rendered output of `p3`.)*


    More flexible version

    If you want to control the replacement value from within stat_*() instead of having it hard-coded as 999, the following changes suffice:

    Add parameter in replacement function to take on different replacement values:

    # Replace every NA in `x` (a vector or a data frame) with `na.replace`.
    # A NULL `na.replace` (e.g. `params$na.replace` when the parameter was
    # never supplied) falls back to the default 999 instead of erroring with
    # "replacement has length zero".
    replace_missing <- function(x, na.replace = 999) {
      if (is.null(na.replace)) {
        na.replace <- 999
      }
      # A single fill value; longer vectors would be recycled unpredictably
      # across the NA positions.
      stopifnot(length(na.replace) == 1)
      x[is.na(x)] <- na.replace
      x
    }
    

    Add na.replace as one of the extra parameters expected by your Stat:

    # StatNew (flexible): like the simple version, but the NA replacement value
    # is taken from the `na.replace` layer parameter (default 999).
    # NOTE(review): ggplot2:::check_required_aesthetics(), ggplot2:::snake_class()
    # and ggplot2:::dapply() are internal and may change between ggplot2 releases.
    StatNew <- ggproto("StatNew", StatSummary,

                       # Extend StatSummary's own extra_params instead of hard-coding them:
                       # the ggplot2 version this was written for declared only na.rm and
                       # orientation, but other versions may declare more, and layer() only
                       # passes a Stat the parameters it declares.
                       extra_params = union(StatSummary$extra_params, "na.replace"),

                       compute_layer = function(self, data, params, layout) {
                         stat_name <- ggplot2:::snake_class(self)

                         # Error early if a required aesthetic is neither mapped
                         # nor supplied as a parameter.
                         ggplot2:::check_required_aesthetics(
                           self$required_aes,
                           c(names(data), names(params)),
                           stat_name
                         )

                         # Resolve "a|b" alternatives in required_aes to the
                         # aesthetics actually present in the data.
                         required_aes <- intersect(
                           names(data),
                           unlist(strsplit(self$required_aes, "|", fixed = TRUE))
                         )

                         # na.replace is absent when the Stat is used without stat_new()
                         # (e.g. via layer()); fall back to 999 rather than erroring.
                         na_value <- params[["na.replace"]]
                         if (is.null(na_value)) {
                           na_value <- 999
                         }

                         # Instead of remove_missing(): replace NAs, but only in the
                         # required / non-missing aesthetic columns (the ones ggplot2
                         # checks), leaving PANEL, group and other aesthetics untouched.
                         na_cols <- intersect(c(required_aes, self$non_missing_aes), names(data))
                         if (length(na_cols) > 0) {
                           data[na_cols] <- replace_missing(data[na_cols], na_value)
                         }

                         # Keep only the parameters compute_panel() accepts (this also
                         # drops na.replace, which compute_panel() does not take).
                         params <- params[intersect(names(params), self$parameters())]

                         ggplot2:::dapply(data, "PANEL", function(data) {
                           scales <- layout$get_scales(data$PANEL[1])
                           # try_fetch()/inject() are exported by rlang, so `::` is enough.
                           rlang::try_fetch(
                             rlang::inject(self$compute_panel(data = data, scales = scales, !!!params)),
                             error = function(cnd) {
                               cli::cli_warn("Computation failed in {.fn {stat_name}}", parent = cnd)
                               # ggplot2's internal data_frame0() is not in scope here;
                               # an empty data frame likewise contributes no rows.
                               data.frame()
                             }
                           )
                         })
                       })
    

    Add na.replace as a parameter of the stat_*() function, and pass it on in params:

    # Layer constructor for StatNew (flexible version). The fun.* arguments,
    # na.rm, orientation, na.replace and anything in `...` are forwarded to the
    # layer as params; na.replace (default 999) is the value that NAs are
    # replaced with.
    stat_new <- function(mapping = NULL, data = NULL, geom = "pointrange", position = "identity",
                         ..., fun.data = NULL, fun = NULL, fun.max = NULL, fun.min = NULL,
                         fun.args = list(), na.rm = FALSE, orientation = NA, show.legend = NA,
                         inherit.aes = TRUE, na.replace = 999) {
      # list2() is exported by rlang, so `::` suffices; `:::` is only needed
      # for a package's internal functions.
      layer(data = data, mapping = mapping, stat = StatNew,
            geom = geom, position = position, show.legend = show.legend,
            inherit.aes = inherit.aes,
            params = rlang::list2(fun.data = fun.data, fun = fun, fun.max = fun.max, fun.min = fun.min,
                                  fun.args = fun.args, na.rm = na.rm, orientation = orientation,
                                  na.replace = na.replace, # the Stat reads it from params
                                  ...))
    }
    

    Usage:

    # Default na.replace (999): no difference from p3.
    p4 <- ggplot(data = mydata, aes(x = x_mydata, y = y_mydata)) +
      stat_new(geom = "bar", fun = mean)
    p4

    # Custom replacement value: NAs become 50, so the bars differ from p3/p4.
    p5 <- ggplot(data = mydata, aes(x = x_mydata, y = y_mydata)) +
      stat_new(geom = "bar", fun = mean, na.replace = 50)
    p5
    

    *(Rendered output of `p5`.)*