I'm trying to improve my classification results by clustering the data and using the cluster labels as an additional feature (or using them alone instead of all the other features; I'm not sure yet).
So let's say that I'm using an unsupervised algorithm, a Gaussian Mixture Model (GMM):
from sklearn.mixture import GaussianMixture
gmm = GaussianMixture(n_components=4, random_state=RSEED)
gmm.fit(X_train)
pred_labels = gmm.predict(X_test)
I trained the model on the training data and predicted the clusters for the test data.
Now I want to use a classifier (KNN, for example) that takes the cluster labels as input. So I tried:
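from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
# Define the model and the parameter grid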
knn = KNeighborsClassifier()
parameters = {'n_neighbors':[3,5,7],
'leaf_size':[1,3,5],
'algorithm':['auto', 'kd_tree'],
'n_jobs':[-1]}
# Fit the model
model_gmm_knn = GridSearchCV(knn, param_grid=parameters)
model_gmm_knn.fit(pred_labels.reshape(-1, 1),Y_train)
model_gmm_knn.best_params_
But I'm getting:
ValueError: Found input variables with inconsistent numbers of samples: [418, 891]
The training and test sets do not have the same number of samples. So how can I implement such an approach?
Your method is not correct: you are attempting to use the cluster labels of your test data, pred_labels, as a single feature in order to fit a classifier with your training labels, Y_train. Even in the highly coincidental case that these two datasets had the same number of samples (and hence no dimension-mismatch error, unlike here), this would be conceptually wrong and would not make any sense: the i-th cluster label of the test set has no relation to the i-th target of the training set.
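What you actually want to do is fit the clustering model on your training data only, use it to predict cluster labels for both the training and the test data, add those labels as a new feature to each set, and then fit your classifier on the augmented training data and evaluate it on the augmented test data.
All in all, and assuming that your X_train and X_test are pandas dataframes, here is the procedure: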
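import pandas as pd
# Fit the clustering model on the training data only
gmm.fit(X_train)
# Predict one cluster label per row, for both sets
cluster_train = gmm.predict(X_train)
cluster_test = gmm.predict(X_test)
# Append the labels as a new feature, aligned on each dataframe's index
X_train['cluster_label'] = pd.Series(cluster_train, index=X_train.index)
X_test['cluster_label'] = pd.Series(cluster_test, index=X_test.index)
# Fit the classifier on the augmented training set, then predict on the augmented test set
model_gmm_knn.fit(X_train, Y_train)
y_pred = model_gmm_knn.predict(X_test)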
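For illustration, here is a self-contained sketch of the same procedure. It assumes synthetic data from make_classification in place of the question's dataset (1309 rows split into 891 training and 418 test rows, only to mirror the sizes in the error message), an assumed value for RSEED, and the question's GMM and KNN settings. It uses DataFrame.assign instead of in-place column assignment so the original frames are left untouched.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

RSEED = 42  # assumed seed; the question does not show its value

# Synthetic stand-in data (assumption): 1309 rows -> 891 train / 418 test
X, y = make_classification(n_samples=1309, n_features=6, n_informative=4,
                           random_state=RSEED)
X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
X_train, X_test, Y_train, Y_test = train_test_split(
    X, y, test_size=418, random_state=RSEED)

# Fit the clustering model on the training data only
gmm = GaussianMixture(n_components=4, random_state=RSEED)
gmm.fit(X_train)

# Append one cluster label per row to each set (new frames, originals unchanged)
X_train_aug = X_train.assign(cluster_label=gmm.predict(X_train))
X_test_aug = X_test.assign(cluster_label=gmm.predict(X_test))

# Same KNN grid as in the question
parameters = {"n_neighbors": [3, 5, 7],
              "leaf_size": [1, 3, 5],
              "algorithm": ["auto", "kd_tree"],
              "n_jobs": [-1]}
model_gmm_knn = GridSearchCV(KNeighborsClassifier(), param_grid=parameters)

# 891 feature rows paired with 891 training labels: the shapes now agree
model_gmm_knn.fit(X_train_aug, Y_train)

# Evaluate on the test set, augmented in exactly the same way
y_pred = model_gmm_knn.predict(X_test_aug)
test_accuracy = model_gmm_knn.score(X_test_aug, Y_test)
print(model_gmm_knn.best_params_, test_accuracy)
```

One design caveat: KNN treats the integer cluster label as a numeric coordinate, so the distance between clusters 0 and 3 is arbitrary. One-hot encoding the label (for example with pd.get_dummies) and scaling the features may work better, but whether it helps depends on your data.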
Notice that you should fit your clustering model with your training data only, never with your test data; otherwise you have data leakage similar to the one encountered when using the test set for feature selection, and your results will be both invalid and misleading.
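The same principle applies, on a smaller scale, inside GridSearchCV: when the GMM is fitted on all of X_train before the search, it has also seen the rows that each cross-validation fold holds out for validation. One way to avoid this is to wrap the clustering step in a scikit-learn Pipeline so that it is refit on every training fold. The sketch below assumes synthetic data again; GMMClusterFeature and its keep_features switch are hypothetical helpers written for this example, not scikit-learn classes. keep_features lets the grid search compare the two options mentioned in the question: the label as an extra feature versus the label alone.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline


class GMMClusterFeature(BaseEstimator, TransformerMixin):
    """Hypothetical transformer: adds the GMM cluster label as a feature.

    keep_features=True  -> original features plus one label column
    keep_features=False -> the label column alone
    """

    def __init__(self, n_components=4, keep_features=True, random_state=None):
        self.n_components = n_components
        self.keep_features = keep_features
        self.random_state = random_state

    def fit(self, X, y=None):
        # Fitted only on the data the pipeline passes in (a CV training fold)
        self.gmm_ = GaussianMixture(n_components=self.n_components,
                                    random_state=self.random_state)
        self.gmm_.fit(X)
        return self

    def transform(self, X):
        labels = self.gmm_.predict(X).reshape(-1, 1)
        if self.keep_features:
            return np.hstack([np.asarray(X), labels])
        return labels


RSEED = 42  # assumed seed
X, y = make_classification(n_samples=1309, n_features=6, n_informative=4,
                           random_state=RSEED)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, y, test_size=418, random_state=RSEED)

pipe = Pipeline([
    ("cluster", GMMClusterFeature(random_state=RSEED)),
    ("knn", KNeighborsClassifier()),
])
# Parameters of pipeline steps are addressed as <step>__<param>
parameters = {"cluster__n_components": [2, 4],
              "cluster__keep_features": [True, False],
              "knn__n_neighbors": [3, 5, 7]}
search = GridSearchCV(pipe, param_grid=parameters)
search.fit(X_train, Y_train)

y_pred = search.predict(X_test)
print(search.best_params_, search.score(X_test, Y_test))
```

With the default refit=True, GridSearchCV refits the best pipeline (GMM included) on the whole of X_train, and that refitted pipeline is what is applied to X_test, so the test set is still never used for fitting.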