Say I am using the titanic dataset, with the variable age only:
import numpy as np
import pandas as pd
data = pd.read_csv('https://www.openml.org/data/get_csv/16826755/phpMYEkMl')[["age", "survived"]]
data = data.replace('?', np.nan)
data = data.fillna(0)
print(data)
The result:
         age  survived
0         29         1
1     0.9167         1
2          2         0
3         30         0
4         25         0
...      ...       ...
1304    14.5         0
1305       0         0
1306    26.5         0
1307      27         0
1308      29         0

[1309 rows x 2 columns]
Now I train a decision tree to predict survival from age:
from sklearn.tree import DecisionTreeClassifier
tree_model = DecisionTreeClassifier(max_depth=3)
tree_model.fit(data['age'].to_frame(),data["survived"])
And if I print the structure of the tree:
from sklearn import tree
print(tree.export_text(tree_model))
I obtain:
|--- feature_0 <= 0.08
| |--- class: 0
|--- feature_0 > 0.08
| |--- feature_0 <= 8.50
| | |--- feature_0 <= 1.50
| | | |--- class: 1
| | |--- feature_0 > 1.50
| | | |--- class: 1
| |--- feature_0 > 8.50
| | |--- feature_0 <= 60.25
| | | |--- class: 0
| | |--- feature_0 > 60.25
| | | |--- class: 0
This means that the final intervals are:
0-0.08; 0.08-1.50; 1.50-8.50; 8.50-60.25; >60.25
My question is: how can I capture those limits in an array that looks like this?
[-np.inf, 0.08, 1.5, 8.5, 60.25, np.inf]
Thank you!
The fitted decision tree classifier, in this case tree_model, has an attribute called tree_ that gives access to low-level attributes such as the split thresholds.
print(tree_model.tree_.threshold)
array([ 0.08335, -2. , 8.5 , 1.5 , -2. , -2. ,
60.25 , -2. , -2. ])
print(tree_model.tree_.feature)
array([ 0, -2, 0, 0, -2, -2, 0, -2, -2], dtype=int64)
The feature and threshold arrays only apply to split nodes. The values for leaf nodes in these arrays are therefore arbitrary (they show up as -2 in the output above).
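If you would rather not rely on the placeholder value, one option is to identify leaves from the child pointers instead: a leaf has no children, so its children_left and children_right entries are equal. The sketch below uses a tiny made-up dataset (X, y and toy_model are hypothetical names, not the Titanic data) just to show the idea.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Tiny synthetic dataset (hypothetical, not the Titanic data):
# one feature, cleanly separable between 3 and 10.
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y = np.array([0, 0, 0, 1, 1, 1])

toy_model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
t = toy_model.tree_

# A leaf has no children, so both child pointers hold the same sentinel value.
is_leaf = t.children_left == t.children_right

# Keep only the entries that belong to real split nodes.
split_features = t.feature[~is_leaf]
split_thresholds = t.threshold[~is_leaf]

print("leaf nodes:", np.flatnonzero(is_leaf))
print("split thresholds:", split_thresholds)
```

On this toy data the tree needs a single split, so only one threshold survives the mask.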
To get the thresholds for a given feature, you can filter the threshold array using the feature array:
threshold = tree_model.tree_.threshold
feature = tree_model.tree_.feature
feature_threshold = threshold[feature == 0]
thresholds = sorted(feature_threshold)
print(thresholds)
[0.08335000276565552, 1.5, 8.5, 60.25]
To include -np.inf and np.inf, you need to add them yourself:
import numpy as np
thresholds = [-np.inf] + thresholds + [np.inf]
print(thresholds)
[-inf, 0.08335000276565552, 1.5, 8.5, 60.25, inf]
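The steps above can be wrapped in a small helper. This is only a sketch: bin_edges_for_feature is a hypothetical name, the input arrays are copied from the printed output above (so the first threshold is the rounded 0.08335), and the sample ages are made up for illustration. It also shows one way to use the edges afterwards, with np.digitize.

```python
import numpy as np

def bin_edges_for_feature(threshold, feature, feature_index=0):
    """Sorted split thresholds for one feature, padded with -inf and inf."""
    threshold = np.asarray(threshold, dtype=float)
    feature = np.asarray(feature)
    # Leaf nodes carry feature == -2, so this mask drops them as well.
    # np.unique sorts and removes thresholds that occur in several branches.
    cuts = np.unique(threshold[feature == feature_index])
    return np.concatenate(([-np.inf], cuts, [np.inf]))

# Arrays copied from the printed output above (values as displayed).
threshold = [0.08335, -2.0, 8.5, 1.5, -2.0, -2.0, 60.25, -2.0, -2.0]
feature = [0, -2, 0, 0, -2, -2, 0, -2, -2]

edges = bin_edges_for_feature(threshold, feature)
print(edges)

# The tree sends samples with value <= threshold to the left branch,
# so each interval is closed on the right: (lower, upper].
ages = np.array([0.0, 0.9167, 2.0, 29.0, 70.0])  # made-up sample values
bin_index = np.digitize(ages, edges, right=True)
print(bin_index)  # -> [1 2 3 4 5]
```

Passing right=True keeps the binning consistent with the tree's "<=" splits, so a value exactly equal to a threshold lands in the lower interval.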
Reference: Understanding the decision tree structure (scikit-learn documentation).