
Find the path/rules for each node in partykit


Is it possible to find the path/rules for each node? I want to extract the rules for each node, and not just for the terminal nodes.

Example:

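library(rpart)      # rpart() is not attached by library(partykit)
library(partykit)

set.seed(1)         # assumed; the question does not set a seed
n <- 1000           # assumed sample size; n is not defined in the question
p <- 5              # assumed number of predictors; p is not defined in the question
X    <- MASS::mvrnorm(n, rep(0, p), diag(p))
y    <- as.numeric(drop(X %*% rep(1, p)) > 2)

data <- data.frame(y, X)

tree      <- rpart(y ~ .,
                   data = data,
                   control = rpart.control(cp = 0.005))
pfit      <- as.party(tree)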

I can use partykit:::.list.rules.party(pfit), but this returns only the rules for the terminal nodes. I'm looking for the rules for every node, including the inner ones.
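A quick way to confirm which nodes the default call covers is to compare its length with the output of nodeids(). The sketch below rebuilds the example under assumed values (n = 1000, p = 5, a fixed seed), none of which the question specifies, and uses matrix(rnorm(...)) in place of MASS::mvrnorm(), which draws from the same distribution when the covariance is the identity.

```r
library(rpart)
library(partykit)

set.seed(1)                        # assumed seed, for reproducibility only
n <- 1000; p <- 5                  # assumed sizes; the question leaves n and p undefined
X <- matrix(rnorm(n * p), n, p)    # same distribution as MASS::mvrnorm(n, rep(0, p), diag(p))
y <- as.numeric(drop(X %*% rep(1, p)) > 2)
dat <- data.frame(y, X)            # columns y, X1, ..., X5

tree <- rpart(y ~ ., data = dat, control = rpart.control(cp = 0.005))
pfit <- as.party(tree)

# Default call: i is NULL, so only terminal nodes are used
term_rules <- partykit:::.list.rules.party(pfit)

all_ids  <- nodeids(pfit)                    # every node, root included
term_ids <- nodeids(pfit, terminal = TRUE)   # leaves only

length(term_rules) == length(term_ids)   # TRUE
length(all_ids) > length(term_ids)       # TRUE: inner nodes are absent from the default output
```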


Solution

  • Pass the IDs of the nodes whose rules you want via the argument i. By default, nodeids() returns all node IDs (use terminal = TRUE to get only the terminal ones):

    R> partykit:::.list.rules.party(pfit, i = nodeids(pfit))
                                                                              1
                                                                             "" 
                                                                              2 
                                                       "X3 < 0.650618460125409" 
                                                                              3 
                               "X3 < 0.650618460125409 & X2 < 1.62837615944647" 
                                                                              4 
       "X3 < 0.650618460125409 & X2 < 1.62837615944647 & X4 < 1.38485264813313" 
    ...
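If a table is more convenient than a named vector, the rules for all nodes can be collected into a data frame together with a terminal-node flag and the number of conditions on each path. This is a sketch that assumes n = 1000, p = 5 and a fixed seed (not given in the question). The n_conditions column is a hypothetical helper computed by counting the "&" separators in each rule string; for binary numeric splits such as these it equals the node's depth.

```r
library(rpart)
library(partykit)

set.seed(1)                        # assumed seed
n <- 1000; p <- 5                  # assumed sizes; not specified in the question
X <- matrix(rnorm(n * p), n, p)    # equivalent to MASS::mvrnorm with identity covariance
y <- as.numeric(drop(X %*% rep(1, p)) > 2)
dat <- data.frame(y, X)

pfit <- as.party(rpart(y ~ ., data = dat, control = rpart.control(cp = 0.005)))

ids   <- nodeids(pfit)                                          # all node IDs (default)
rules <- unname(partykit:::.list.rules.party(pfit, i = ids))    # one rule string per node

rules_df <- data.frame(
  node         = ids,
  terminal     = ids %in% nodeids(pfit, terminal = TRUE),
  # hypothetical helper column: count the "&"-joined conditions; the root's rule is ""
  n_conditions = ifelse(rules == "", 0L, nchar(gsub("[^&]", "", rules)) + 1L),
  rule         = rules,
  stringsAsFactors = FALSE
)

head(rules_df)
```

To restrict the output to a subtree, passing i = nodeids(pfit, from = 2) should return the rules for node 2 and its descendants in the same way. Because .list.rules.party is an internal function (hence the :::), its behavior is not guaranteed to stay the same across partykit versions.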