python machine-learning scikit-learn xgboost decision-tree

How to visualize an XGBoost tree from GridSearchCV output?


I am using XGBRegressor with GridSearchCV to fit the model, and I want to visualize the trees.

In case this is a duplicate, here is the question I followed: How to plot a decision tree from GridSearchCV?

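from xgboost import XGBRegressor
from sklearn.model_selection import GridSearchCV

xgb = XGBRegressor(learning_rate=0.02, n_estimators=600, silent=True, nthread=1)
folds = 5
# params: the hyperparameter grid (its definition is not shown here)
grid = GridSearchCV(estimator=xgb, param_grid=params, scoring='neg_mean_squared_error', cv=folds, n_jobs=4, verbose=3)
model = grid.fit(X_train, y_train)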

Approach 1:

from sklearn import tree

dot_data = tree.export_graphviz(model.best_estimator_, out_file=None,
                                filled=True, rounded=True, feature_names=X_train.columns)
dot_data

 Error: NotFittedError: This XGBRegressor instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

Approach 2:

import subprocess
from sklearn import tree
tree.export_graphviz(best_clf, out_file='tree.dot', feature_names=X_train.columns, leaves_parallel=True)  # best_clf: presumably model.best_estimator_
subprocess.call(['dot', '-Tpdf', 'tree.dot', '-o', 'tree.pdf'])

Same error.


Solution

  • scikit-learn's tree.export_graphviz will not work here: it only handles scikit-learn decision tree estimators, and your best_estimator_ is not a single tree but an XGBoost model containing a whole ensemble of trees.

    Here is how you can do it using XGBoost's own plot_tree and the California housing data (the Boston dataset loader, load_boston, was removed in scikit-learn 1.2):

    from xgboost import XGBRegressor, plot_tree
    from sklearn.model_selection import GridSearchCV
    from sklearn.datasets import fetch_california_housing
    import matplotlib.pyplot as plt
    
    X, y = fetch_california_housing(return_X_y=True)
    
    params = {'learning_rate': [0.1, 0.5], 'n_estimators': [5, 10]}  # dummy, for demonstration only
    
    xgb = XGBRegressor(learning_rate=0.02, n_estimators=600, silent=True, nthread=1)
    grid = GridSearchCV(estimator=xgb, param_grid=params, scoring='neg_mean_squared_error', n_jobs=4)
    
    grid.fit(X, y)
    

    Our best estimator is:

    grid.best_estimator_
    # result (details may be different due to randomness):
    XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
                 colsample_bynode=1, colsample_bytree=1, gamma=0,
                 importance_type='gain', learning_rate=0.5, max_delta_step=0,
                 max_depth=3, min_child_weight=1, missing=None, n_estimators=10,
                 n_jobs=1, nthread=1, objective='reg:linear', random_state=0,
                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
                 silent=True, subsample=1, verbosity=1)
    

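    Having done that, and using the answer from this SO thread, we can plot, say, the tree at index 4 (num_trees is zero-based, so this is the fifth tree in the ensemble). Note that plot_tree needs the graphviz Python package and the Graphviz binaries installed: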

    fig, ax = plt.subplots(figsize=(30, 30))
    plot_tree(grid.best_estimator_, num_trees=4, ax=ax)
    plt.show()
    

    [image: plot of the tree at index 4]

    Similarly, for the tree at index 1 (the second tree):

    fig, ax = plt.subplots(figsize=(30, 30))
    plot_tree(grid.best_estimator_, num_trees=1, ax=ax)
    plt.show()
    

    [image: plot of the tree at index 1]