Starting from the scikit-learn example that plots a hierarchical-clustering dendrogram with SciPy (changing only `truncate_mode` from `'level'` to `'lastp'`):
import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering
def plot_dendrogram(model, **kwargs):
    """Plot a dendrogram for a fitted ``AgglomerativeClustering`` model.

    Builds a SciPy linkage matrix from ``model.children_``,
    ``model.distances_`` and the number of samples under each merge node,
    then passes it to ``scipy.cluster.hierarchy.dendrogram``.

    Parameters
    ----------
    model : fitted AgglomerativeClustering
        Must expose ``children_``, ``labels_`` and ``distances_``
        (scikit-learn only computes ``distances_`` when the model is fitted
        with ``distance_threshold`` set or ``compute_distances=True``).
    **kwargs
        Forwarded unchanged to ``dendrogram`` (e.g. ``truncate_mode``, ``p``,
        ``no_plot``, ``labels``).

    Returns
    -------
    dict
        The dictionary returned by ``dendrogram`` (``'ivl'``, ``'leaves'``,
        ``'icoord'``, ``'dcoord'``, ...), so callers can inspect the leaf
        order of the plotted (possibly truncated) tree.
    """
    n_samples = len(model.labels_)
    # counts[i] = number of original samples under the node created by merge i.
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node: a single original sample
            else:
                # Internal node child_idx was created by merge
                # child_idx - n_samples, an earlier merge, so its count
                # has already been filled in.
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    # SciPy linkage format: one row per merge, [child_a, child_b, distance, count].
    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram and hand back its layout information.
    return dendrogram(linkage_matrix, **kwargs)
iris = load_iris()
X = iris.data

# Setting distance_threshold=0 (with n_clusters=None) ensures we compute the
# full tree, including the merge distances plot_dendrogram needs.
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
model = model.fit(X)

plt.title("Hierarchical Clustering Dendrogram")
# Show only the last p=10 merged clusters (truncate_mode="lastp"); a number in
# parentheses under a leaf is the count of samples in that node.
plot_dendrogram(model, truncate_mode="lastp", p=10)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()
I get the following truncated dendrogram:
Now, how can I get the elements that belong to each of the clusters shown at that truncation? I would like to know the indices of the 21, 17, 12, etc. samples that make up each of those clusters.
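This is a bit hacky, but you can compute the full dendrogram with custom leaf labels (each sample labelled with its cluster from a flat clustering that has as many clusters as the truncated plot has leaves), and then recover the left-to-right cluster order from the return value of `dendrogram`.
The sample indices of a given cluster are then simply `np.flatnonzero(model.labels_ == label)`.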
# Fit a model with the desired number of flat clusters. compute_full_tree keeps
# the whole merge tree and compute_distances records distances_, both of which
# plot_dendrogram needs to build the linkage matrix.
model = AgglomerativeClustering(
    n_clusters=10, compute_full_tree=True, compute_distances=True
).fit(X)
# Return dendrogram output
def plot_dendrogram(model, **kwargs):
    """Plot (or just lay out) the dendrogram of a fitted clustering model.

    Builds a SciPy linkage matrix from ``model.children_``,
    ``model.distances_`` and the number of samples under each merge node,
    and passes it to ``scipy.cluster.hierarchy.dendrogram``.

    Parameters
    ----------
    model : fitted AgglomerativeClustering
        Must expose ``children_``, ``labels_`` and ``distances_``.
    **kwargs
        Forwarded unchanged to ``dendrogram``. Any ``dendrogram`` option may
        be passed, including ``no_labels``, ``no_plot``, ``labels``,
        ``truncate_mode`` and ``p``.

    Returns
    -------
    dict
        The dictionary returned by ``dendrogram`` (``'ivl'``, ``'leaves'``,
        ``'icoord'``, ``'dcoord'``, ...).
    """
    n_samples = len(model.labels_)
    # counts[i] = number of original samples under the node created by merge i.
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node: a single original sample
            else:
                # Internal node from an earlier merge; its count is already known.
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    # SciPy linkage format: one row per merge, [child_a, child_b, distance, count].
    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    # No hard-coded no_labels=False here: it is dendrogram's default anyway,
    # and fixing it would make a caller-supplied no_labels raise TypeError.
    return dendrogram(linkage_matrix, **kwargs)
# Plot the truncated dendrogram with one leaf per flat cluster of the model
# (p tied to the fitted cluster count so the two always agree).
d = plot_dendrogram(model, leaf_rotation=90, truncate_mode='lastp', p=model.n_clusters_)
# Compute the full dendrogram without plotting, labelling every sample leaf
# with the flat-cluster label the fitted model assigned to it.
d2 = plot_dendrogram(model, leaf_rotation=90, no_plot=True, labels=model.labels_)
# d2['ivl'] holds those labels in left-to-right leaf order.
leaf_labels = np.array(d2['ivl'])
# np.unique(..., return_index=True)[1] gives each label's first position;
# sorting the positions orders the labels as they appear left to right, which
# should match the leaf order of the truncated plot (the basis of this trick).
first_positions = np.unique(leaf_labels, return_index=True)[1]
labels = [int(leaf_labels[idx]) for idx in sorted(first_positions)]
print(labels)
# Sample indices in each truncated leaf, left to right. d['ivl'] is the text
# shown under that leaf ("(count)" or a single sample index) for cross-checking.
for leaf_text, label in zip(d['ivl'], labels):
    members = np.flatnonzero(model.labels_ == label)
    print(leaf_text, label, members)
plt.show()
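# Example output of print(labels) (presumably from a run with n_clusters=12, since 12 distinct labels are listed):
# [10, 4, 6, 1, 8, 11, 3, 7, 2, 5, 9, 0]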