I have a trained MeanShift object (ms). It exposes a simple list of centers in cluster_centers_. How do I determine which label each center belongs to?
I am aware of labels_, but I do not see the connection between labels_ and cluster_centers_.
```
print(ms.cluster_centers_)
[[  40.7177164   -73.99183542]
 [  33.44943805 -112.00213969]
 [  33.44638027 -111.90188756]
 ...,
 [  46.7323875  -117.0001651 ]
 [  29.6899563   -95.8996757 ]
 [  31.3787916   -95.3213317 ]]
```
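labels_ has one entry per sample, so its length equals the number of rows in your original dataset. Each entry is the index of the cluster that sample was assigned to, i.e. a row index into cluster_centers_. So the center associated with entry i of the original data is cluster_centers_[labels_[i]], and, conversely, the center in row k of cluster_centers_ is the one with label k.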
In the scikit-learn example, they loop over the unique labels and use labels == k to select all the data with that label (X[labels == k]):
https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py
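A short sketch of the same loop pattern, pairing each label k with its members and its center. It assumes scikit-learn is installed; the toy data, bandwidth=2.0 and the sizes dictionary are illustrative assumptions, not part of the linked example:

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Hypothetical toy data standing in for the original dataset
X, _ = make_blobs(n_samples=300, centers=[[0, 0], [10, 10], [-10, 10]],
                  cluster_std=0.5, random_state=0)
ms = MeanShift(bandwidth=2.0).fit(X)  # bandwidth chosen for this toy data
labels = ms.labels_

sizes = {}  # hypothetical helper: number of samples per cluster label
for k in np.unique(labels):
    members = X[labels == k]         # all samples assigned to cluster k
    center = ms.cluster_centers_[k]  # label k corresponds to row k of the centers
    sizes[int(k)] = len(members)
    print(f"cluster {k}: {len(members)} points, center {center}")
```

Every sample lands in exactly one cluster, so the per-cluster sizes add up to the number of rows in X.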