I have a HistGradientBoostingClassifier model and I want to plot one or more of its decision trees. However, I can't find a native function to do it. I can access the TreePredictor objects, and thus their nodes, but sklearn.tree.plot_tree only accepts a DecisionTreeClassifier or DecisionTreeRegressor object.
I tried this:
from sklearn.tree import plot_tree
plot_tree(RF_90._predictors[0][0])
but got this error:
InvalidParameterError: The 'decision_tree' parameter of plot_tree must be an instance of 'sklearn.tree._classes.DecisionTreeClassifier' or an instance of 'sklearn.tree._classes.DecisionTreeRegressor'. Got <sklearn.ensemble._hist_gradient_boosting.predictor.TreePredictor object at 0x7f676ebf0310> instead.
Note: RF_90 is the fitted HistGradientBoostingClassifier model.
To visualize the trees generated by HistGradientBoostingClassifier, this function worked for me. It walks the node array of a single TreePredictor and builds a graphviz.Digraph from it:
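import graphviz

def visualize_tree(tree, feature_names, class_names=None):
    # class_names is accepted for convenience but is not used
    dot = graphviz.Digraph()

    def add_nodes_edges(dot, nodes, node_id):
        node = nodes[node_id]
        if node['is_leaf']:
            # Raw leaf value (a score contribution), not a class label
            value = node['value']
            dot.node(str(node_id), f"value: {value:.4f}")
        else:
            feature = feature_names[node['feature_idx']]
            if node['is_categorical']:
                dot.node(str(node_id), f"{feature} (categorical split)")
            else:
                # num_threshold is the threshold in the original feature units;
                # bin_threshold is only the index of the bin
                threshold = node['num_threshold']
                dot.node(str(node_id), f"{feature} <= {threshold:.2f}")
            left_child = node['left']
            right_child = node['right']
            dot.edge(str(node_id), str(left_child), "True")
            dot.edge(str(node_id), str(right_child), "False")
            add_nodes_edges(dot, nodes, left_child)
            add_nodes_edges(dot, nodes, right_child)

    # TreePredictor stores its nodes as a structured NumPy array; node 0 is the root
    nodes = tree.nodes
    add_nodes_edges(dot, nodes, 0)
    return dot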
# Create and visualize the tree
dot = visualize_tree(single_tree, RF_90.feature_names_in_, 1)
dot.render("hist_gb_tree")  # save the DOT source and a rendered PDF (the default format) to disk
dot  # display the graph inline in Jupyter
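where RF_90 is the fitted model and single_tree is one TreePredictor:
single_tree = trees_per_iteration[iteration][class_index]
with trees_per_iteration = RF_90._predictors (a private attribute holding one list of trees per boosting iteration), iteration = 0 and class_index = 0. For binary classification each iteration contains a single tree, so class_index must be 0; for multiclass problems there is one tree per class.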