python scikit-learn mean-shift

How to get the center of a specific cluster in sklearn.cluster.MeanShift


I have a trained MeanShift object (ms). It has a simple list of centers. How do I determine which label a given center belongs to? I am aware of labels_, but I do not see the connection between labels_ and cluster_centers_.

print(ms.cluster_centers_)

[[  40.7177164   -73.99183542]
 [  33.44943805 -112.00213969]
 [  33.44638027 -111.90188756]
 ..., 
 [  46.7323875  -117.0001651 ]
 [  29.6899563   -95.8996757 ]
 [  31.3787916   -95.3213317 ]]

Solution

  • labels_ has one entry per sample in your original dataset, and each entry is the index of the cluster that sample was assigned to. That index is a row index into cluster_centers_, so the center for label k is cluster_centers_[k], and the center associated with entry i in the original data is cluster_centers_[labels_[i]].

    The scikit-learn example does exactly this: it loops over the unique labels and uses the boolean mask labels_ == k to select all the data with that label (X[labels_ == k]): https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py
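
    A minimal sketch of both lookups. The synthetic two-blob data and bandwidth=2.0 are assumptions chosen only so the example runs on its own; with your own data you would use your already fitted ms and the array it was fitted on.

```python
import numpy as np
from sklearn.cluster import MeanShift

# Synthetic data (assumption for illustration): two well-separated 2-D blobs
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(loc=(0.0, 0.0), scale=0.3, size=(50, 2)),
    rng.normal(loc=(5.0, 5.0), scale=0.3, size=(50, 2)),
])

# bandwidth=2.0 is an arbitrary choice that suits this toy data
ms = MeanShift(bandwidth=2.0).fit(X)

labels = ms.labels_            # shape (n_samples,): cluster index of each sample
centers = ms.cluster_centers_  # shape (n_clusters, n_features): row k is the center of label k

# Center associated with every sample, via integer-array indexing
sample_centers = centers[labels]

# Per-cluster view: label k, its center, and the samples assigned to it
for k in np.unique(labels):
    members = X[labels == k]
    print(f"label {k}: center={centers[k]}, n_members={len(members)}")

# Going the other way: predicting the label of a center returns its own row index
center_labels = ms.predict(centers)
```

    Here centers[labels] gives the per-sample centers in one step, and ms.predict(centers) confirms that the center stored in row k carries label k.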