Can someone explain what the predict() method in scikit-learn's KMeans implementation is used for? The official documentation describes it as:
Predict the closest cluster each sample in X belongs to.
But I can also get the cluster number/label for each sample of the input set X when training the model with fit_transform(). So what is the use of the predict() method? Is it supposed to point out the closest cluster for unseen data? If so, how do you handle a new data point if you apply a dimensionality reduction step such as SVD?
Here's a similar question, but I still don't think it really helps.
what is the use of predict() method? Is it supposed to point out closest cluster for the unseen data?
Yes, exactly.
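Here is a minimal sketch of what that means for points the model has never seen. It assumes synthetic 2-D data from make_blobs, with n_clusters=3, n_init=10 and random_state=0 chosen arbitrarily for illustration. predict() gives each new sample the index of its nearest cluster center (Euclidean distance), and the same labels can be reproduced by hand from cluster_centers_:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Illustrative training data (an assumption): 300 points around 3 blob centers
x_train, _ = make_blobs(n_samples=300, centers=3, random_state=0)
km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(x_train)

# "Unseen" points that were not part of the training data
x_new = np.array([[0.0, 0.0], [5.0, 5.0], [-5.0, 2.0]])
labels_new = km.predict(x_new)

# Same result by hand: distance from every new point to every fitted center,
# then the index of the closest center for each point
dists = np.linalg.norm(
    x_new[:, np.newaxis, :] - km.cluster_centers_[np.newaxis, :, :], axis=2
)
labels_manual = np.argmin(dists, axis=1)

print(labels_new, labels_manual)
```

The model is not refitted here: the cluster centers learned from x_train stay fixed, and new points are only assigned to them.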
then how do you handle a new data point if you perform dimensionality reduction measure such as SVD?
You apply the same, already fitted, dimensionality reduction to the unseen data before passing it to .predict(). Here is a typical workflow:
```python
# prerequisites:
# x_train: training data
# x_test: "unseen" testing data
# km: initialized `KMeans()` instance
# dr: initialized dimensionality reduction instance (such as `TruncatedSVD()`)

# fitting
x_dr = dr.fit_transform(x_train)
y = km.fit_predict(x_dr)
# ...

# working with unseen data (models have been fitted before)
x_dr = dr.transform(x_test)
y = km.predict(x_dr)
# ...
```
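A self-contained version of that workflow might look like the sketch below. The random data (a stand-in for something like TF-IDF vectors), the 75/25 split, n_components=5, n_clusters=4 and random_state=0 are illustrative assumptions, not part of the original answer. The key point is that dr and km are fitted on x_train only, and x_test goes through transform() and predict():

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import TruncatedSVD
from sklearn.model_selection import train_test_split

# Illustrative data (an assumption): 200 samples with 50 non-negative features
rng = np.random.default_rng(0)
x = rng.random((200, 50))
x_train, x_test = train_test_split(x, test_size=0.25, random_state=0)

dr = TruncatedSVD(n_components=5, random_state=0)
km = KMeans(n_clusters=4, n_init=10, random_state=0)

# fitting: learn the projection and the cluster centers from training data only
x_dr = dr.fit_transform(x_train)
y = km.fit_predict(x_dr)

# unseen data: reuse the fitted projection, then assign each sample to its nearest center
x_test_dr = dr.transform(x_test)
y_test = km.predict(x_test_dr)

print(x_test_dr.shape, y_test[:10])
```

Calling dr.fit_transform(x_test) instead would learn a different projection, so the test points would no longer live in the space the cluster centers were fitted in.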
Methods such as fit_transform and fit_predict exist for convenience: y = km.fit_predict(x) is equivalent to y = km.fit(x).predict(x).
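A quick way to check this equivalence is the sketch below, assuming blob data from make_blobs and a fixed random_state so that two separately constructed models reach the same fit. In recent scikit-learn versions the training labels are also kept in the fitted model's labels_ attribute, which is why a model fitted with fit_transform() (which returns distances to the centers, not labels) still exposes the labels afterwards:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Illustrative data (an assumption): 200 points around 4 blob centers
x, _ = make_blobs(n_samples=200, centers=4, random_state=1)

# Two identically configured models; the fixed random_state makes both fits identical
y_a = KMeans(n_clusters=4, n_init=10, random_state=0).fit_predict(x)
km_b = KMeans(n_clusters=4, n_init=10, random_state=0).fit(x)
y_b = km_b.predict(x)

# fit_predict(x) matches fit(x).predict(x), and both match the stored labels_
print(np.array_equal(y_a, y_b), np.array_equal(y_b, km_b.labels_))
```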
I think it's easier to see what's going on if we write the fitting part as follows:
```python
# fitting
dr.fit(x_train)
x_dr = dr.transform(x_train)
km.fit(x_dr)
y = km.predict(x_dr)
```
Except for the calls to .fit(), the models are used in exactly the same way during fitting and on unseen data.
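The "fit once, then only transform/predict" pattern is what scikit-learn's Pipeline automates. This is a swapped-in alternative to the manual code, not something the answer itself uses, and the data and parameter values below are illustrative assumptions. make_pipeline names each step after its lowercased class name, so the fitted steps can be looked up as "truncatedsvd" and "kmeans" to confirm that the pipeline does the same thing as the manual route:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import make_pipeline

# Illustrative data (an assumption): separate "training" and "unseen" samples
rng = np.random.default_rng(42)
x_train = rng.random((150, 30))
x_test = rng.random((20, 30))

# pipe.fit fits TruncatedSVD on x_train, transforms it, then fits KMeans on the result
pipe = make_pipeline(
    TruncatedSVD(n_components=3, random_state=0),
    KMeans(n_clusters=3, n_init=10, random_state=0),
)
pipe.fit(x_train)

# pipe.predict only calls transform on the SVD step and predict on KMeans (no refitting)
y_pipe = pipe.predict(x_test)

# Equivalent manual route using the already fitted steps
dr = pipe.named_steps["truncatedsvd"]
km = pipe.named_steps["kmeans"]
y_manual = km.predict(dr.transform(x_test))

print(np.array_equal(y_pipe, y_manual))
```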
Summary:
- .fit() trains the model on data.
- .predict() or .transform() applies a trained model to data.
- .fit_predict() or .fit_transform() do both in a single call, for convenience.