In Python hierarchical clustering from pairwise distances, how can I cut the tree at a specific distance and get the clusters and the list of members of each cluster?


I have pairwise distance data like this:

    obj_distances = {
        ('DN1357_i2', 'DN1357_i5'): 1.0,
        ('DN1357_i2', 'DN10172_i1'): 28.0,
        ('DN1357_i2', 'DN1357_i1'): 8.0,
        ('DN1357_i5', 'DN1357_i1'): 2.0,
        ('DN1357_i5', 'DN10172_i1'): 34.0,
        ('DN1357_i1', 'DN10172_i1'): 38.0,
    }

So I have 4 objects, which I clustered using this code:

    from scipy.cluster.hierarchy import dendrogram, linkage

    keys = [sorted(k) for k in obj_distances.keys()]
    values = obj_distances.values()
    sorted_keys, distances = zip(*sorted(zip(keys, values)))

    Z = linkage(distances)

    labels = sorted(set([key[0] for key in sorted_keys] + [sorted_keys[-1][-1]]))
    dendro = dendrogram(Z, labels=labels)

This gives me a dendrogram. How can I get the clusters and the names of the objects in each cluster if I cut the dendrogram at distance 2?


Solution

  • You can use the SciPy function cut_tree. Here's an example for your data:

    from scipy.cluster.hierarchy import cut_tree, dendrogram, linkage
    
    obj_distances = {
        ('DN1357_i2', 'DN1357_i5'): 1.0,
        ('DN1357_i2', 'DN10172_i1'): 28.0,
        ('DN1357_i2', 'DN1357_i1'): 8.0,
        ('DN1357_i5', 'DN1357_i1'): 2.0,
        ('DN1357_i5', 'DN10172_i1'): 34.0,
        ('DN1357_i1', 'DN10172_i1'): 38.0,
    }
    
    keys = [sorted(k) for k in obj_distances.keys()]
    values = obj_distances.values()
    sorted_keys, distances = zip(*sorted(zip(keys, values)))
    
    Z = linkage(distances)
    
    labels = sorted(set([key[0] for key in sorted_keys] + [sorted_keys[-1][-1]]))
    dendro = dendrogram(Z, labels=labels)
    
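    # cut_tree returns one row per observation, in the same order as `labels`
    # (the order of the condensed distances), not in dendrogram leaf order,
    # so the ids are matched against `labels` directly and left unsorted.
    clusters = cut_tree(Z, height=2)
    cluster_ids = [c[0] for c in clusters]
    
    for k in range(max(cluster_ids) + 1):
        print(f"Cluster {k}")
        for i, c in enumerate(cluster_ids):
            if c == k:
                print(f"{labels[i]}")
    
        print('\n')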
    
    

    For cutting the tree at a height of 2, the output is:

    Cluster 0
    DN10172_i1
    
    
    Cluster 1
    DN1357_i1
    
    
    Cluster 2
    DN1357_i2
    DN1357_i5
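
As a complementary sketch (not part of the answer above), the members of each cluster can also be collected into a dictionary keyed by cluster id instead of being printed. The helper name `members_by_cluster` is hypothetical, and it assumes the cluster ids come back in the same order as `labels`, which is the observation order used to build the condensed distances. In the output above, the merge that happens at exactly distance 2 is not applied by `cut_tree(Z, height=2)`. `fcluster` with `criterion="distance"` is documented to keep observations together when their cophenetic distance is no greater than `t`, so it is shown alongside for comparison.

```python
from collections import defaultdict

import numpy as np
from scipy.cluster.hierarchy import cut_tree, fcluster, linkage

obj_distances = {
    ('DN1357_i2', 'DN1357_i5'): 1.0,
    ('DN1357_i2', 'DN10172_i1'): 28.0,
    ('DN1357_i2', 'DN1357_i1'): 8.0,
    ('DN1357_i5', 'DN1357_i1'): 2.0,
    ('DN1357_i5', 'DN10172_i1'): 34.0,
    ('DN1357_i1', 'DN10172_i1'): 38.0,
}

# Same construction as above: sorting the (sorted) pairs yields the
# condensed distance order for the alphabetically sorted labels.
keys = [sorted(k) for k in obj_distances]
sorted_keys, distances = zip(*sorted(zip(keys, obj_distances.values())))
labels = sorted({name for pair in sorted_keys for name in pair})

# linkage defaults to single linkage, as in the code above.
Z = linkage(np.asarray(distances, dtype=float))


def members_by_cluster(cluster_ids, labels):
    """Hypothetical helper: group labels by cluster id.

    Assumes cluster_ids[i] belongs to labels[i].
    """
    groups = defaultdict(list)
    for label, cid in zip(labels, cluster_ids):
        groups[int(cid)].append(label)
    return dict(groups)


# cut_tree returns an (n_observations, 1) array for a single height.
cut_groups = members_by_cluster(cut_tree(Z, height=2).ravel(), labels)

# fcluster returns a flat array of ids (numbered from 1).
flat_groups = members_by_cluster(
    fcluster(Z, t=2, criterion="distance"), labels
)

print(cut_groups)
print(flat_groups)
```

If the intent is for objects joined at exactly the cut distance to end up in the same cluster, the `fcluster` grouping is probably the one to use; otherwise the `cut_tree` grouping matches the output shown above.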