I want to plot a partition of a two-dimensional covariate space constructed by recursive binary splitting. More precisely, I would like to write a function that replicates the following graph (taken from The Elements of Statistical Learning, p. 306):
Displayed above is a two-dimensional covariate space and a partition obtained by recursively splitting it with axis-aligned binary splits (the approach used by the CART algorithm). What I want to implement is a function that takes the output of the rpart function and generates such a plot.
Here is some example code:
library(rpart)

## Generating data.
set.seed(1975)
n <- 5000
p <- 2
X <- matrix(sample(seq(0, 1, by = 0.01), n * p, replace = TRUE), ncol = p)
Y <- X[, 1] + 2 * X[, 2] + rnorm(n)

## Building tree (data.frame(Y, X) names the covariates X1 and X2).
tree <- rpart(Y ~ ., data = data.frame(Y, X), method = "anova",
              control = rpart.control(cp = 0, maxdepth = 2))
Searching Stack Overflow, I found this function:
rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    ix <- cumsum(c(1L, nn))
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = FALSE))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    left <- character(nrow(splits))
    side <- splits[, 2L]
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
        paste("<", format(signif(splits[i, 4L], digits)))
      else if (side[i] == 1L)
        paste(">=", format(signif(splits[i, 4L], digits)))
      else {
        catside <- fit$csplit[splits[i, 4L], 1:side[i]]
        paste(c("L", "-", "R")[catside], collapse = "", sep = "")
      }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = FALSE))
  }
}
Using this function, I am able to recover all the splitting variables and points:
splits <- rpart_splits(tree)[rpart_splits(tree)$type == "main", ]
splits
#   var type node ix    left count ncat    improve index adj
# 1  X2 main    1  1 < 0.565  5000   -1 0.18110662 0.565   0
# 3  X2 main    2  2 < 0.265  2814   -1 0.06358597 0.265   0
# 6  X1 main    3  5 < 0.645  2186   -1 0.07645851 0.645   0
The column var gives the splitting variable for each non-terminal node, and the column left gives the associated split point. However, I do not know how to use this information to produce the plot I want.
Of course, if you have an alternative strategy that does not involve rpart_splits, feel free to suggest it.
You could use the (unpublished) parttree package, which you can install from GitHub via:
remotes::install_github("grantmcdermott/parttree")
With it, you can draw the partition directly from the fitted tree:
library(ggplot2)
library(parttree)

ggplot() +
  geom_parttree(data = tree, aes(fill = path)) +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  scale_fill_brewer(palette = "Pastel1", name = "Partitions") +
  theme_bw(base_size = 16) +
  labs(x = "X2", y = "X1")
Incidentally, this package also contains the function parttree, which returns something very similar to your rpart_splits function:
parttree(tree)
  node         Y                        path  xmin  xmax  ymin  ymax
1    4 0.7556079   X2 < 0.565 --> X2 < 0.265  -Inf 0.265  -Inf   Inf
2    5 1.3087679  X2 < 0.565 --> X2 >= 0.265 0.265 0.565  -Inf   Inf
3    6 1.8681143  X2 >= 0.565 --> X1 < 0.645 0.565   Inf  -Inf 0.645
4    7 2.4993361 X2 >= 0.565 --> X1 >= 0.645 0.565   Inf 0.645   Inf
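If you would rather not depend on a GitHub-only package, the same rectangles can probably be recovered from the rpart object itself and drawn with base graphics. The following is only a rough sketch, not how parttree does it internally: leaf_boxes is a hypothetical helper written for this example, and it assumes exactly two numeric covariates, that the first row of fit$splits for each internal node is its primary split, and that the children of node k are numbered 2k and 2k + 1 (the numbering rpart uses for row names of fit$frame).

```r
library(rpart)

# Hypothetical helper (not part of rpart or parttree): returns one row per
# leaf with the bounds of its rectangle. Assumes two numeric covariates,
# named by xvar (horizontal axis) and yvar (vertical axis).
leaf_boxes <- function(fit, xvar, yvar, xlim = c(0, 1), ylim = c(0, 1)) {
  ff      <- fit$frame
  nodes   <- as.integer(row.names(ff))
  is_leaf <- ff$var == "<leaf>"
  # Rows of fit$splits used by each node (primary + competitors + surrogates).
  n_rows  <- ff$ncompete + ff$nsurrogate + !is_leaf
  first   <- cumsum(c(1L, n_rows))[seq_along(n_rows)]  # primary split row

  # Start from the whole plotting window and cut it node by node.
  boxes <- list("1" = c(xmin = xlim[1], xmax = xlim[2],
                        ymin = ylim[1], ymax = ylim[2]))
  # fit$frame is stored parent-before-child, so a parent's box always exists.
  for (i in which(!is_leaf)) {
    box <- boxes[[as.character(nodes[i])]]
    v   <- as.character(ff$var[i])
    stopifnot(v %in% c(xvar, yvar))
    cut <- fit$splits[first[i], "index"]  # split point
    dir <- fit$splits[first[i], "ncat"]   # -1: "< cut" goes to the left child
    below <- box
    below[if (v == xvar) "xmax" else "ymax"] <- cut
    above <- box
    above[if (v == xvar) "xmin" else "ymin"] <- cut
    boxes[[as.character(2L * nodes[i])]]      <- if (dir < 0) below else above
    boxes[[as.character(2L * nodes[i] + 1L)]] <- if (dir < 0) above else below
  }
  leaves <- as.character(nodes[is_leaf])
  cbind(node = as.integer(leaves),
        as.data.frame(do.call(rbind, boxes[leaves])))
}

# Same simulated data and tree as in the question.
set.seed(1975)
n <- 5000
X <- matrix(sample(seq(0, 1, by = 0.01), n * 2, replace = TRUE), ncol = 2)
Y <- X[, 1] + 2 * X[, 2] + rnorm(n)
tree <- rpart(Y ~ ., data = data.frame(Y, X), method = "anova",
              control = rpart.control(cp = 0, maxdepth = 2))

b <- leaf_boxes(tree, xvar = "X2", yvar = "X1")

# Empty unit-square canvas, then one rectangle and one label per leaf.
plot(c(0, 1), c(0, 1), type = "n", xaxs = "i", yaxs = "i",
     xlab = "X2", ylab = "X1")
rect(b$xmin, b$ymin, b$xmax, b$ymax)
text((b$xmin + b$xmax) / 2, (b$ymin + b$ymax) / 2,
     labels = paste0("R", seq_len(nrow(b))))
```

The axes are laid out as in the ggplot call above (X2 horizontal, X1 vertical). Passing xlim and ylim plays the role of the -Inf/Inf bounds in the parttree output: the outer rectangles are simply clipped to the plotting window. Categorical splits are not handled.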