python, scikit-learn, decision-tree, pruning

Prune unnecessary leaves in sklearn DecisionTreeClassifier


I use sklearn.tree.DecisionTreeClassifier to build a decision tree. With the optimal parameter settings, I get a tree that has unnecessary leaves (see the example picture below; since I do not need class probabilities, the leaf nodes marked in red are an unnecessary split).

[Image: the fitted tree, with the redundant leaf nodes marked in red]

Is there any third-party library for pruning these unnecessary nodes? Or a code snippet? I could write one, but I can't really imagine that I am the first person with this problem...

Code to replicate:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

# Fit a tree on the iris data; max_leaf_nodes=8 reproduces the redundant leaves
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X, y)
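
Since the image is not reproduced here, the fitted tree can be inspected in text form instead; a minimal sketch using sklearn.tree.export_text (available in scikit-learn 0.21+):

from sklearn.tree import export_text

# Sibling leaves that report the same majority class are the
# redundant splits described above
print(export_text(mdl, feature_names=iris.feature_names))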

PS: I have tried multiple keyword searches and am kind of surprised to find nothing. Is there really no general post-pruning in sklearn?

PPS: In response to the possible duplicate: while the suggested question might help me when coding the pruning algorithm myself, it answers a different question: I want to get rid of leaves that do not change the final decision, while the other question wants a minimum threshold for splitting nodes.

PPPS: The tree shown is an example to show my problem. I am aware of the fact that the parameter settings used to create the tree are suboptimal. I am not asking about optimizing this specific tree; I need to do post-pruning to get rid of leaves that might be helpful if one needs class probabilities, but are not helpful if one is only interested in the most likely class.
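
To make the distinction concrete (a hypothetical example; the class counts are made up for illustration): two sibling leaves can disagree in predict_proba() yet agree in predict(), so the split separating them changes the probabilities but never the predicted class:

import numpy as np

# Hypothetical per-class sample counts for two sibling leaves
left, right = np.array([0, 1, 39]), np.array([0, 2, 2])

print(left / left.sum(), right / right.sum())  # different class probabilities
print(left.argmax(), right.argmax())           # same majority class: 2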


Solution

  • Using ncfirth's link, I was able to modify the code there so that it fits my problem:

    from sklearn.tree._tree import TREE_LEAF
    
    def is_leaf(inner_tree, index):
        # Check whether node is leaf node
        return (inner_tree.children_left[index] == TREE_LEAF and 
                inner_tree.children_right[index] == TREE_LEAF)
    
    def prune_index(inner_tree, decisions, index=0):
        # Start pruning from the bottom - if we start from the top, we might miss
        # nodes that become leaves during pruning.
        # Do not use this directly - use prune_duplicate_leaves instead.
        if not is_leaf(inner_tree, inner_tree.children_left[index]):
            prune_index(inner_tree, decisions, inner_tree.children_left[index])
        if not is_leaf(inner_tree, inner_tree.children_right[index]):
            prune_index(inner_tree, decisions, inner_tree.children_right[index])
    
        # Prune children if both children are leaves now and make the same decision:     
        if (is_leaf(inner_tree, inner_tree.children_left[index]) and
            is_leaf(inner_tree, inner_tree.children_right[index]) and
            (decisions[index] == decisions[inner_tree.children_left[index]]) and 
            (decisions[index] == decisions[inner_tree.children_right[index]])):
            # turn node into a leaf by "unlinking" its children
            inner_tree.children_left[index] = TREE_LEAF
            inner_tree.children_right[index] = TREE_LEAF
            # print("Pruned {}".format(index))  # uncomment to log pruned nodes
    
    def prune_duplicate_leaves(mdl):
        # Prune splits whose leaves all make the same class decision
        decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist()  # majority class per node
        prune_index(mdl.tree_, decisions)
    

    Using this on a DecisionTreeClassifier clf:

    prune_duplicate_leaves(clf)
    

    Edit: Fixed a bug for more complex trees
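
    To sanity-check the pruning (a minimal sketch, reusing mdl and X from the question's snippet): since only splits whose leaves agree on the majority class are collapsed, predict() must return exactly the same labels before and after:

    before = mdl.predict(X)
    prune_duplicate_leaves(mdl)
    after = mdl.predict(X)

    # Only splits whose leaves agree on the class were collapsed,
    # so the predictions are unchanged
    assert (before == after).all()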