How to plot the recursive partitioning from the rpart package


I want to plot a partition of a two-dimensional covariate space constructed by recursive binary splitting. More precisely, I would like to write a function that reproduces the following graph (taken from The Elements of Statistical Learning, p. 306):

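[Figure: a two-dimensional covariate space partitioned into rectangles by recursive binary splitting, from The Elements of Statistical Learning]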

The figure above shows a two-dimensional covariate space and a partition of it obtained by recursive binary splitting with axis-aligned splits (as done by the CART algorithm). I want to implement a function that takes the output of the rpart function and generates such a plot.

Here is some example code:

library(rpart)

## Generating data.
set.seed(1975)

n <- 5000
p <- 2

X <- matrix(sample(seq(0, 1, by = 0.01), n * p, replace = TRUE), ncol = p)
Y <- X[, 1] + 2 * X[, 2] + rnorm(n)

## Building tree.
tree <- rpart(Y ~ ., data = data.frame(Y, X), method = "anova",
              control = rpart.control(cp = 0, maxdepth = 2))

Searching Stack Overflow, I found this function:

rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    ix <- cumsum(c(1L, nn))
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = FALSE))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    left <- character(nrow(splits))
    side <- splits[, 2L]
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
                   paste("<", format(signif(splits[i, 4L], digits)))
                 else if (side[i] == 1L)
                   paste(">=", format(signif(splits[i, 4L], digits)))
                 else {
                   catside <- fit$csplit[splits[i, 4L], 1:side[i]]
                   paste(c("L", "-", "R")[catside], collapse = "", sep = "")
                 }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = FALSE))
  }
}

Using this function, I am able to recover all the splitting variables and points:

splits <- rpart_splits(tree)[rpart_splits(tree)$type == "main", ]
splits

#   var type node ix    left count ncat    improve index adj
# 1  X2 main    1  1 < 0.565  5000   -1 0.18110662 0.565   0
# 3  X2 main    2  2 < 0.265  2814   -1 0.06358597 0.265   0
# 6  X1 main    3  5 < 0.645  2186   -1 0.07645851 0.645   0

The var column gives the splitting variable for each non-terminal node, and the left column gives the associated split point together with the condition that sends an observation to the left child. However, I do not know how to use this information to produce the plot I want.

Of course, if you have an alternative strategy that does not involve rpart_splits, feel free to suggest it.


Solution

  • You could use the (unpublished) parttree package, which you can install from GitHub via:

    remotes::install_github("grantmcdermott/parttree")
    

    This allows:

    library(parttree)
    library(ggplot2)
    
    ggplot() +
      geom_parttree(data = tree, aes(fill = path)) +
      coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
      scale_fill_brewer(palette = "Pastel1", name = "Partitions") +
      theme_bw(base_size = 16) +
      labs(x = "X2", y = "X1")
    

    [Figure: the resulting plot, with X2 on the x-axis, X1 on the y-axis, and the four partitions filled by path]

    Incidentally, the package also provides the parttree() function, which returns a data frame similar to the output of your rpart_splits function, with one row per terminal node and the bounds of its rectangle:

    parttree(tree)
      node         Y                        path  xmin  xmax  ymin  ymax
    1    4 0.7556079   X2 < 0.565 --> X2 < 0.265  -Inf 0.265  -Inf   Inf
    2    5 1.3087679  X2 < 0.565 --> X2 >= 0.265 0.265 0.565  -Inf   Inf
    3    6 1.8681143  X2 >= 0.565 --> X1 < 0.645 0.565   Inf  -Inf 0.645
    4    7 2.4993361 X2 >= 0.565 --> X1 >= 0.645 0.565   Inf 0.645   Inf
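
If you would rather not depend on parttree, the same rectangles can be recovered from the rpart object itself. What follows is only a minimal sketch under stated assumptions, not how parttree does it: leaf_boxes and plot_partition are hypothetical helper names, all predictors are assumed numeric, the tree is assumed to split only on the two plotted variables, and the covariate space is assumed to be the unit square, as in the question's data. It relies on rpart's node numbering (the children of node k are 2k and 2k + 1) and locates each node's main split in fit$splits the same way rpart_splits does. X1 is put on the x-axis here, as in the book's figure.

```r
library(rpart)

## Data and tree from the question.
set.seed(1975)
n <- 5000
p <- 2
X <- matrix(sample(seq(0, 1, by = 0.01), n * p, replace = TRUE), ncol = p)
Y <- X[, 1] + 2 * X[, 2] + rnorm(n)
tree <- rpart(Y ~ ., data = data.frame(Y, X), method = "anova",
              control = rpart.control(cp = 0, maxdepth = 2))

## Hypothetical helper: one row per leaf, with its rectangle in (xvar, yvar).
## Assumes numeric predictors and that the tree splits only on xvar and yvar.
leaf_boxes <- function(fit, xvar, yvar, xlim = c(0, 1), ylim = c(0, 1)) {
  ff <- fit$frame
  nodes <- as.integer(row.names(ff))
  is_leaf <- ff$var == "<leaf>"
  ## Row of fit$splits holding each node's main split (as in rpart_splits).
  nn <- ff$ncompete + ff$nsurrogate + !is_leaf
  main_row <- cumsum(c(1L, nn))[seq_len(nrow(ff))]
  boxes <- list()

  walk <- function(node, box) {  # box = c(xmin, xmax, ymin, ymax)
    i <- match(node, nodes)
    if (is_leaf[i]) {
      boxes[[length(boxes) + 1L]] <<- data.frame(
        node = node, yval = ff$yval[i],
        xmin = box[1], xmax = box[2], ymin = box[3], ymax = box[4])
      return(invisible(NULL))
    }
    v <- as.character(ff$var[i])
    stopifnot(v %in% c(xvar, yvar))
    cutpoint <- fit$splits[main_row[i], "index"]
    k <- if (v == xvar) 1L else 3L               # position of v's lower bound
    below <- box
    below[k + 1L] <- min(box[k + 1L], cutpoint)  # region with v < cutpoint
    above <- box
    above[k] <- max(box[k], cutpoint)            # region with v >= cutpoint
    ## ncat == -1: "v < cutpoint" goes to the left child (2 * node);
    ## ncat == +1: it goes to the right child (2 * node + 1).
    if (fit$splits[main_row[i], "ncat"] < 0) {
      walk(2 * node, below); walk(2 * node + 1, above)
    } else {
      walk(2 * node, above); walk(2 * node + 1, below)
    }
  }
  walk(1, c(xlim, ylim))
  do.call(rbind, boxes)
}

## Hypothetical helper: draw the rectangles and label each with its fitted value.
plot_partition <- function(boxes, xlab, ylab) {
  plot(0, 0, type = "n", xlab = xlab, ylab = ylab,
       xlim = range(boxes$xmin, boxes$xmax),
       ylim = range(boxes$ymin, boxes$ymax))
  rect(boxes$xmin, boxes$ymin, boxes$xmax, boxes$ymax)
  text((boxes$xmin + boxes$xmax) / 2, (boxes$ymin + boxes$ymax) / 2,
       labels = round(boxes$yval, 2))
}

boxes <- leaf_boxes(tree, xvar = "X1", yvar = "X2")
plot_partition(boxes, xlab = "X1", ylab = "X2")
```

Starting from finite limits rather than -Inf and Inf keeps the rectangles directly drawable with base graphics; for data on another scale, pass the ranges of the two covariates as xlim and ylim.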