I was wondering whether the scikit-learn library supports hierarchical classification. I am dealing with 3 classes, each divided into 6 subclasses, for example:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
X = np.random.randn(5, 1)
number, rows, cols = 5, 3, 6
y = np.zeros((number, rows, cols), dtype=int)
for n in range(number):
    for row in range(rows):
        col = np.random.randint(cols)
        y[n, row, col] = 1
tree = DecisionTreeClassifier()
tree.fit(X, y)
but I get this error:
ValueError: Found array with dim 3. DecisionTreeClassifier expected <= 2.
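The error comes from the shape of y: scikit-learn classifiers accept a target of shape (n_samples,) or, for multi-output problems, (n_samples, n_outputs), but here y has shape (5, 3, 6). The following is a minimal sketch, assuming each of the 3 classes always has exactly one active subclass per sample. It collapses the one-hot axis with argmax, which turns the problem into an ordinary multi-output one that DecisionTreeClassifier can fit. Note that this treats the 3 classes as independent outputs and does not exploit any hierarchy.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
number, rows, cols = 5, 3, 6
X = rng.standard_normal((number, 1))

# One-hot target of shape (number, rows, cols), built as in the question
y = np.zeros((number, rows, cols), dtype=int)
for n in range(number):
    for row in range(rows):
        y[n, row, rng.integers(cols)] = 1

# Collapse the one-hot axis: y_2d[n, row] is the index of the active subclass
y_2d = y.argmax(axis=2)  # shape (5, 3), which scikit-learn accepts

tree = DecisionTreeClassifier(random_state=0)
tree.fit(X, y_2d)  # multi-output fit on a 2-D target
pred = tree.predict(X)
print(pred.shape)  # (5, 3)
```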
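scikit-learn itself has no hierarchical classifiers, and its estimators only accept targets with at most 2 dimensions, which is what the error is telling you. You can use hiclass, a library of local hierarchical classifiers that works with scikit-learn estimators: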
pip install hiclass
Train the model:
from sklearn.tree import DecisionTreeClassifier
from hiclass.MultiLabelLocalClassifierPerNode import MultiLabelLocalClassifierPerNode
tree = DecisionTreeClassifier()
classifier = MultiLabelLocalClassifierPerNode(local_classifier=tree)
classifier.fit(X, y)
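The classifier.fit(X, y) call assumes y is already in hiclass's label format, not the one-hot (n_samples, 3, 6) array from the question. As I understand the hiclass documentation, the multi-label classifiers expect y as nested label paths of shape (n_samples, n_labels, n_levels). The sketch below converts the one-hot array into that form, assuming an active entry y[n, i, j] means "sample n belongs to class i, subclass j". The function name one_hot_to_paths and the label names are hypothetical placeholders; subclass names are prefixed with their parent so they stay unique across classes.

```python
import numpy as np


def one_hot_to_paths(y, class_names=None, sub_names=None):
    """Convert a (n_samples, n_classes, n_subclasses) one-hot array into
    nested [class, subclass] label paths, one path per active entry.
    Default label names are hypothetical placeholders."""
    n_samples, n_classes, n_subs = y.shape
    if class_names is None:
        class_names = [f"class_{i}" for i in range(n_classes)]
    if sub_names is None:
        sub_names = [f"sub_{j}" for j in range(n_subs)]
    paths = []
    for sample in y:
        sample_paths = []
        for i, row in enumerate(sample):
            # Every non-zero column in this row is an active subclass of class i
            for j in np.flatnonzero(row):
                parent = class_names[i]
                sample_paths.append([parent, f"{parent}/{sub_names[j]}"])
        paths.append(sample_paths)
    return paths


# Usage: one sample with one active subclass under each of the 3 classes
y_demo = np.zeros((1, 3, 6), dtype=int)
y_demo[0, 0, 2] = y_demo[0, 1, 0] = y_demo[0, 2, 5] = 1
print(one_hot_to_paths(y_demo))
```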
Then predict on held-out data and measure precision:
from hiclass.metrics import precision

# X_test, y_test: held-out data in the same format as X and y
predictions = classifier.predict(X_test)
p = precision(y_test, predictions)
print(p)
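Hierarchical precision differs from flat precision in that it gives partial credit when a prediction is correct at the upper level but wrong below it. A common definition adds each predicted and true label's ancestors to its label set and then counts the overlap, micro-averaged over samples. To my understanding hiclass.metrics.precision computes a micro-averaged hierarchical precision by default, but check its documentation for the exact averaging. The sketch below is a plain-Python illustration of that definition for single-label paths only. The names path_nodes and hierarchical_precision are hypothetical, and treating empty strings as padding is an assumption.

```python
def path_nodes(path):
    """Return the nodes on a label path as prefix tuples, e.g.
    ["A", "A1"] -> {("A",), ("A", "A1")}. Empty strings are treated as padding."""
    nodes = set()
    for depth in range(len(path)):
        if path[depth] == "":
            break
        nodes.add(tuple(path[: depth + 1]))
    return nodes


def hierarchical_precision(y_true, y_pred):
    """Micro-averaged hierarchical precision for single-label paths:
    sum_i |P_i & T_i| / sum_i |P_i|, where P_i and T_i include all ancestors."""
    overlap = predicted = 0
    for true_path, pred_path in zip(y_true, y_pred):
        true_nodes = path_nodes(true_path)
        pred_nodes = path_nodes(pred_path)
        overlap += len(true_nodes & pred_nodes)
        predicted += len(pred_nodes)
    return overlap / predicted if predicted else 0.0


# The first prediction is right at the top level only, the second is fully right
y_true = [["A", "A1"], ["B", "B2"]]
y_pred = [["A", "A2"], ["B", "B2"]]
print(hierarchical_precision(y_true, y_pred))  # 0.75
```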
Refer to the hiclass paper for additional information.