python-3.x scikit-learn training-data imbalanced-data oversampling

Using sklearn's train_test_split for imbalanced data


I have a very imbalanced dataset. I used sklearn's train_test_split function to extract the training set. Now I want to oversample the training set, so I counted the samples of each class (my dataset has two classes, type1 and type2). However, almost all of my training data are type1, so I can't oversample.

Previously I split the data into train and test sets with my own code, which put 80% of the type1 samples and 80% of the type2 samples into the training set.
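The per-class 80/20 split described above can be sketched with NumPy roughly like this. `per_class_split` is a hypothetical helper name (not an sklearn function), and the 95/5 toy data is invented for illustration.

```python
import numpy as np

def per_class_split(X, y, train_frac=0.8, seed=0):
    """Hypothetical helper: put train_frac of every class into the training set."""
    X, y = np.asarray(X), np.asarray(y)
    rng = np.random.default_rng(seed)
    train_idx = []
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)  # positions of this class
        rng.shuffle(idx)                  # random order within the class
        n_train = int(round(train_frac * len(idx)))
        train_idx.extend(idx[:n_train])
    train_mask = np.zeros(len(y), dtype=bool)
    train_mask[train_idx] = True
    return X[train_mask], X[~train_mask], y[train_mask], y[~train_mask]

# Made-up example: 95 samples of type1, 5 of type2
X = np.arange(100).reshape(-1, 1)
y = np.array(["type1"] * 95 + ["type2"] * 5)
X_train, X_test, y_train, y_test = per_class_split(X, y)
print(np.unique(y_train, return_counts=True))  # 76 type1, 4 type2
```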

How can I get the same behavior with train_test_split or another splitting method in sklearn?

*I can only use sklearn or methods I write myself.


Solution

  • You're looking for stratification, i.e. a split that keeps each class's proportion the same in the train and test sets.

    train_test_split has a stratify parameter, to which you pass the labels, e.g.:

    from sklearn.model_selection import train_test_split
    # stratify=y keeps the class proportions of y in both splits
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        stratify=y,
                                                        test_size=0.2)
    

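    There's also StratifiedShuffleSplit (in sklearn.model_selection), which returns train/test indices for one or more stratified random splits.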
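A minimal sketch of the whole workflow, assuming a NumPy feature matrix and labels 0 (type1) and 1 (type2): a stratified split, then random oversampling of the minority class with sklearn.utils.resample. Oversampling is applied to the training set only, so duplicated minority samples never end up in the test set. The toy data is made up for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

# Made-up data: 95 samples of class 0 (type1), 5 of class 1 (type2)
X = np.arange(200).reshape(100, 2)
y = np.array([0] * 95 + [1] * 5)

# stratify=y keeps the 95:5 ratio in both splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=42)
print(np.bincount(y_train), np.bincount(y_test))  # 76/4 in train, 19/1 in test

# Random oversampling of the minority class, on the training set only
minority = y_train == 1
X_min_up, y_min_up = resample(
    X_train[minority], y_train[minority],
    replace=True,                      # sample with replacement
    n_samples=int((~minority).sum()),  # match the majority class count
    random_state=42)
X_train_bal = np.vstack([X_train[~minority], X_min_up])
y_train_bal = np.concatenate([y_train[~minority], y_min_up])
print(np.bincount(y_train_bal))  # 76 of each class
```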
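For comparison, a sketch of StratifiedShuffleSplit with n_splits=1, which yields a single stratified 80/20 split as index arrays (similar in effect to train_test_split with stratify=y). The toy data is again invented for illustration.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Made-up data: 95 samples of class 0 (type1), 5 of class 1 (type2)
X = np.arange(200).reshape(100, 2)
y = np.array([0] * 95 + [1] * 5)

# n_splits=1 -> one stratified split; split() yields (train_idx, test_idx) pairs
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(sss.split(X, y))
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]
print(np.bincount(y_train), np.bincount(y_test))  # 76/4 in train, 19/1 in test
```

With n_splits greater than 1 the same object yields several independent stratified splits, which is handy for repeated evaluation.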