I am trying to understand how an rpart tree grows, so I am growing a tree step by step, and I am finding some strange (?) behavior. Let me show this with an example: I will use the titanic data set and the rpart package to grow the tree. Here is the code:
#######################
###load packages
#######################
library(tidyverse)
library(rpart)
library(partykit)
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
library(titanic)
library(ggparty)
z0<-titanic_train
###formatting factors
for(i in c("Survived","Pclass","Sex","Embarked")){
z0[,i]<- factor(z0[,i])
}
z0[,"Survived"]<-factor(z0[,"Survived"],labels = c("No","Yes"))
#First split
fitTT1 <- rpart(Survived ~ Pclass + Age + Sex + Fare, data = z0,
                control = rpart.control(maxdepth = 1))
fitTT1 %>% as.party() %>% autoplot()
#2nd split
fitTT2 <- rpart(Survived ~ Pclass + Age + Sex + Fare, data = z0,
                control = rpart.control(maxdepth = 2))
fitTT2 %>% as.party() %>% autoplot()
#3rd split
fitTT3 <- rpart(Survived ~ Pclass + Age + Sex + Fare, data = z0,
                control = rpart.control(maxdepth = 3))
fitTT3 %>% as.party() %>% autoplot()
Created on 2023-12-07 with reprex v2.0.2
As you can see, the first split is based on Sex, and in the second split, only the node associated with males is split, but not the node of females. In the third split, the female node is now split into three nodes.
My question is: why, in the second round of splits, is the female node not split in the same way as the male node, since the variable Pclass can make a split in this node as well? How does the maxdepth parameter work? Why is the split of females not considered a split of depth 2? Is there a way to obtain a tree with four leaves: male-Age>6.5, male-Age<6.5, female-Pclass3, female-Pclass1,2?
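The function rpart() has many options. Here is a setting that does what you want. It suppresses cross-validation (xval = 0) and does not require a positive gain per split (cp = -1). With the default cp = 0.01, rpart does not attempt any split that fails to decrease the overall lack of fit by a factor of cp, which is presumably why the female node stays unsplit at maxdepth = 2.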
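fitTT2 <- rpart(
  Survived ~ Pclass + Age + Sex + Fare,
  data = z0,
  maxdepth = 2,                         # at most two levels of splits
  method = "class",                     # classification tree
  parms = list(split = "information"),  # information (entropy) splitting instead of Gini
  xval = 0,                             # no cross-validation
  cp = -1                               # keep splits even without a positive gain
)
fitTT2 |>
  as.party() |>
  autoplot()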
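To check the effect of cp without looking at the plot, you can count the terminal nodes of two fits that differ only in cp. This is just a minimal sketch: count_leaves() is a hypothetical helper written for this example (it is not part of rpart), and it assumes the data are prepared as in the question. It keeps the default Gini splitting so that cp is the only thing that changes.

```r
library(rpart)
library(titanic)

# Rebuild the data as in the question
z0 <- titanic_train
for (i in c("Pclass", "Sex", "Embarked")) {
  z0[, i] <- factor(z0[, i])
}
z0$Survived <- factor(z0$Survived, labels = c("No", "Yes"))

# Hypothetical helper (not an rpart function): count the terminal nodes,
# which rpart marks as "<leaf>" in the `var` column of fit$frame
count_leaves <- function(fit) sum(fit$frame$var == "<leaf>")

# Same model and depth limit, default cp = 0.01
fit_default <- rpart(Survived ~ Pclass + Age + Sex + Fare,
                     data = z0, method = "class", maxdepth = 2)

# Same model and depth limit, but cp = -1 so no minimum gain is required
fit_forced <- rpart(Survived ~ Pclass + Age + Sex + Fare,
                    data = z0, method = "class", maxdepth = 2, cp = -1)

n_default <- count_leaves(fit_default)
n_forced <- count_leaves(fit_forced)
cat("leaves with default cp:", n_default, "\n")
cat("leaves with cp = -1:   ", n_forced, "\n")

# Text view of the splits that were kept
print(fit_forced)
```

If the counts differ, the extra leaves come from splits that the default cp would have discarded; maxdepth = 2 only caps the depth, it does not force every node at depth 1 to be split.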