
Scikit HDBSCAN *tree* labeling (not single-slice labeling)


BLUF: For a specific epsilon (or for HDBSCAN's 'favorite' epsilon), I can extract the mapping of my data in that epsilon's partition. But how can I see my data's full tree membership?

I've gotten a ton out of the terrific tutorial here. In scikit-learn-contrib's hdbscan package, I can use clusterer.labels_ to see the partition labels for the best epsilon, and I can use clusterer.single_linkage_tree_.get_clusters(0.023, min_cluster_size=2) to see the partition labels for an arbitrary epsilon. I can even plot the entire dendrogram using clusterer.condensed_tree_.plot(). But how do I see the dendrogram's labels for individual data points?

For example: it's nice to know that my pets' names are {Spot, Felix, Nemo, Fido, Tigger}, or that their species are {Dog, Cat, Guppy, Dog, Cat}. But I'd like one output that tells me:

Spot Dog Mammal Animal
Felix Cat Mammal Animal
Nemo Guppy Fish Animal
Fido Dog Mammal Animal
Tigger Cat Mammal Animal

With this sort of output, I could see precisely how related Spot and Felix are, instead of only answering separate yes/no questions such as "Do they have the same species?" and "Do they have the same kingdom?"
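To make the idea concrete, here is a small pure-Python sketch using the pets from the example above (written root to leaf, which is the reverse of the order hdbscan paths are printed in below). Any single-slice labeling is just one column of the full path, and "how related" two items are can be measured as the length of their shared path prefix. The helpers `slice_labels` and `shared_depth` are hypothetical illustrations, not part of hdbscan.

```python
# Toy data from the question: each pet's path from the root of the
# hierarchy down to its species.
PATHS = {
    "Spot":   ["Animal", "Mammal", "Dog"],
    "Felix":  ["Animal", "Mammal", "Cat"],
    "Nemo":   ["Animal", "Fish", "Guppy"],
    "Fido":   ["Animal", "Mammal", "Dog"],
    "Tigger": ["Animal", "Mammal", "Cat"],
}


def slice_labels(paths, level):
    """Flat (single-slice) labeling: each item's ancestor at one fixed depth."""
    return {name: path[level] for name, path in paths.items()}


def shared_depth(paths, a, b):
    """Number of leading path entries that items a and b have in common."""
    depth = 0
    for x, y in zip(paths[a], paths[b]):
        if x != y:
            break
        depth += 1
    return depth


print(slice_labels(PATHS, 2))                # the "species" slice
print(shared_depth(PATHS, "Spot", "Fido"))   # same species
print(shared_depth(PATHS, "Spot", "Felix"))  # same class only
print(shared_depth(PATHS, "Spot", "Nemo"))   # same kingdom only
```

A larger shared depth means the two items stay in the same cluster further down the hierarchy.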


Solution

  • The clusterer.condensed_tree_ object has a number of conversion utilities, e.g. to_pandas() and to_networkx(). For this particular use case, it looks like you want to print an ancestor list for each leaf node in the condensed tree. You can accomplish this in many ways, but a pretty straightforward one is to convert the tree to a networkx graph and use the utility methods on it to extract the structure you're looking for:

    import hdbscan
    import networkx as nx
    import numpy as np
    
    # run HDBSCAN
    data = np.load('clusterable_data.npy')
    clusterer = hdbscan.HDBSCAN(min_cluster_size=15).fit(data)
    
    # convert the condensed tree to a networkx DiGraph (edges point parent -> child)
    tree = clusterer.condensed_tree_.to_networkx()
    assert nx.is_tree(tree)
    
    # find the root by picking an arbitrary node and walking up
    # (every node except the root has exactly one predecessor)
    root = 0
    while True:
        try:
            root = next(tree.predecessors(root))
        except StopIteration:
            break
    
    # a single search from the root gives the root -> node path for every node;
    # reverse each path so it reads leaf -> parent -> ... -> root
    paths_from_root = nx.shortest_path(tree, source=root)
    all_ancestors = [paths_from_root[leaf_node][::-1] for leaf_node in range(len(data))]
    

    Printing all_ancestors will give you something like:

    [[0, 2324, 2319, 2317, 2312, 2311, 2309],
     [1, 2319, 2317, 2312, 2311, 2309],
     [2, 2319, 2317, 2312, 2311, 2309],
     [3, 2333, 2324, 2319, 2317, 2312, 2311, 2309],
     [4, 2324, 2319, 2317, 2312, 2311, 2309],
     [5, 2334, 2332, 2324, 2319, 2317, 2312, 2311, 2309],
     ...
     [995, 2309],
     [996, 2318, 2317, 2312, 2311, 2309],
     [997, 2318, 2317, 2312, 2311, 2309],
     [998, 2318, 2317, 2312, 2311, 2309],
     [999, 2318, 2317, 2312, 2311, 2309],
     ...]
    

    The first entry in each list is the node ID (for a leaf, this is the point's index in the data array), the second entry is that node's parent, and so on up to the root (which, in this case, has the ID 2309). Note that any node ID greater than or equal to the number of data points is a "cluster node" (i.e. an inner node of the tree; hdbscan gives the root the ID len(data)), and any lower node ID is a "data point node" (i.e. a leaf node in the tree).
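    The leaf/cluster distinction is just a comparison against the number of data points. A minimal sketch, assuming n_points = 2309 (inferred from the root ID in the example output, since the root is numbered len(data)); the helper `split_path` is hypothetical:

```python
def split_path(path, n_points):
    """Split a leaf -> root ancestor list into (point index, cluster chain).

    IDs below n_points are data points (leaves); IDs >= n_points are
    cluster nodes, the root being numbered exactly n_points.
    """
    leaf, clusters = path[0], list(path[1:])
    if leaf >= n_points:
        raise ValueError("first entry should be a data point")
    if any(node < n_points for node in clusters):
        raise ValueError("every ancestor should be a cluster node")
    return leaf, clusters


# Paths copied from the example output; N_POINTS is an assumption
# inferred from the root ID 2309.
N_POINTS = 2309
print(split_path([3, 2333, 2324, 2319, 2317, 2312, 2311, 2309], N_POINTS))
print(split_path([995, 2309], N_POINTS))  # point attached directly to the root
```

    The length of the cluster chain is the point's depth in the condensed tree.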

    It may be slightly easier to understand this list format by sorting the nodes into their clusters, e.g. with:

    all_ancestors.sort(key=lambda path: path[1:])
    

    Printing all_ancestors now will give you something like:

    [[21, 2309],
     [126, 2309],
     [152, 2309],
     [155, 2309],
     [156, 2309],
     [172, 2309],
     ...
     [1912, 2313, 2311, 2309],
     [1982, 2313, 2311, 2309],
     [2014, 2313, 2311, 2309],
     [2028, 2313, 2311, 2309],
     [2071, 2313, 2311, 2309],
     ...
     [1577, 2337, 2314, 2310, 2309],
     [1585, 2337, 2314, 2310, 2309],
     [1591, 2337, 2314, 2310, 2309],
     [1910, 2337, 2314, 2310, 2309],
     [2188, 2337, 2314, 2310, 2309]]
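    Because Python compares lists element by element, sorting on path[1:] orders points by their immediate parent cluster first, then by the grandparent, and so on. The sketch below reuses a few rows copied from the output above (not the full result) and additionally groups points by their immediate parent with itertools.groupby:

```python
from itertools import groupby

# A few leaf -> root paths copied from the example output, in shuffled order.
all_ancestors = [
    [1577, 2337, 2314, 2310, 2309],
    [21, 2309],
    [1912, 2313, 2311, 2309],
    [126, 2309],
    [1585, 2337, 2314, 2310, 2309],
]

# Lists compare lexicographically, so this orders by parent, then grandparent, ...
all_ancestors.sort(key=lambda path: path[1:])

# After sorting, points sharing an immediate parent are adjacent, so groupby
# collects each cluster's direct members (points in child clusters are not included).
members = {
    parent: [path[0] for path in group]
    for parent, group in groupby(all_ancestors, key=lambda path: path[1])
}
print(members)
```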
    

    There are many equivalent ways to get the same result (e.g. by looping over the pandas dataframe produced by to_pandas()), but networkx is a natural choice for pretty much anything you might want to do with trees/DAGs/graphs.
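    As a dependency-free illustration of the looping approach, the sketch below takes (parent, child) pairs (the same pairing the parent and child columns of to_pandas() provide), builds a child -> parent map, identifies the root as the only node that never appears as a child, and walks up from each data point. The edge list is hypothetical toy data numbered like hdbscan's condensed tree (points 0..n-1, root n); it is not real hdbscan output.

```python
def ancestor_lists(edges, n_points):
    """Build a leaf -> root ancestor list for each data point.

    edges: iterable of (parent, child) pairs describing a tree.
    n_points: number of data points; leaves are numbered 0 .. n_points - 1.
    """
    parent_of = {}
    for parent, child in edges:
        parent_of[child] = parent

    # The root appears as a parent but never as a child.
    roots = set(parent_of.values()) - set(parent_of)
    if len(roots) != 1:
        raise ValueError(f"expected exactly one root, found {roots}")

    all_ancestors = []
    for leaf in range(n_points):
        path = [leaf]
        while path[-1] in parent_of:  # stops once the root is reached
            path.append(parent_of[path[-1]])
        all_ancestors.append(path)
    return all_ancestors


# Hypothetical toy tree with 5 points: root 5 splits into clusters 6 and 7,
# while point 4 falls out of the root directly (like point 995 above).
toy_edges = [(5, 6), (5, 7), (5, 4), (6, 0), (6, 1), (7, 2), (7, 3)]
for path in ancestor_lists(toy_edges, n_points=5):
    print(path)
```

    With a real clusterer, the pairs could come from zip(df["parent"], df["child"]) on the dataframe returned by to_pandas().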