python scikit-learn random-forest class-variables

scikit-learn RandomForestClassifier list all variables of an estimator tree?


I train a RandomForestClassifier as follows:

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
X, y = make_classification()
clf = RandomForestClassifier()
clf.fit(X,y)

where X and y are some feature vectors and labels.

Once the fit is done, I can, for example, list the depth of the tree grown by each estimator in the forest as follows:

[estimator.tree_.max_depth for estimator in clf.estimators_]

Now I would like to find out all other public variables (apart from max_depth) a tree_ within an estimator stores. So I tried:

vars(clf.estimators_[0].tree_)

but unfortunately this does not work and raises the error

TypeError: vars() argument must have __dict__ attribute

What syntax can I use to successfully list all public variables in an estimator.tree_?


Solution

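  • vars() cannot list these attributes because tree_ is an instance of the Cython class Tree, which has no __dict__. The built-in dir() still works on it, and the documentation of the Tree class (for example via help(sklearn.tree._tree.Tree)) describes each attribute.

    For more details, see: https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html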
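
    As a minimal sketch of the dir() approach (not taken from the linked page): the snippet below filters out names starting with an underscore and separates data attributes from methods. The variable names (tree, public_names, attributes, methods) and the random_state and n_estimators values are illustrative assumptions, and the exact set of names returned may differ between scikit-learn versions.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Same setup as in the question; random_state and n_estimators are
# assumptions made here for reproducibility and speed.
X, y = make_classification(random_state=0)
clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

tree = clf.estimators_[0].tree_

# dir() inspects the type as well as the instance, so it works even
# though the object has no __dict__. Drop private/dunder names.
public_names = [name for name in dir(tree) if not name.startswith("_")]

# Separate stored data (arrays, counters) from callable methods.
attributes = [n for n in public_names if not callable(getattr(tree, n))]
methods = [n for n in public_names if callable(getattr(tree, n))]

print("attributes:", attributes)
print("methods:", methods)
```

    The attributes list should include max_depth alongside entries such as node_count, children_left, children_right, feature, threshold and value, which are the arrays walked through in the linked example.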