Split a pandas DataFrame using sklearn's KFold


I obtained the row indices of the training set and the test set with the code below.

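import pandas
from sklearn.model_selection import KFold

# filepath, filename, n_splits, shuffle and randomState are defined elsewhere
df = pandas.read_pickle(filepath + filename)
kf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=randomState)

# take the first (train, test) pair of row indices
result = next(kf.split(df), None)

# train indices can be accessed with result[0]
# test indices can be accessed with result[1]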

Is there a fast way to split the DataFrame into two DataFrames, one for training and one for testing, using the row indices I retrieved?


Solution

  • Use DataFrame.iloc to select rows by position, because KFold.split returns integer positions rather than index labels:

    Sample:

    import numpy as np
    import pandas as pd

    np.random.seed(100)
    df = pd.DataFrame(np.random.random((10,5)), columns=list('ABCDE'))
    # non-default index, so labels differ from positions
    df.index = df.index * 10
    print (df)
               A         B         C         D         E
    0   0.543405  0.278369  0.424518  0.844776  0.004719
    10  0.121569  0.670749  0.825853  0.136707  0.575093
    20  0.891322  0.209202  0.185328  0.108377  0.219697
    30  0.978624  0.811683  0.171941  0.816225  0.274074
    40  0.431704  0.940030  0.817649  0.336112  0.175410
    50  0.372832  0.005689  0.252426  0.795663  0.015255
    60  0.598843  0.603805  0.105148  0.381943  0.036476
    70  0.890412  0.980921  0.059942  0.890546  0.576901
    80  0.742480  0.630184  0.581842  0.020439  0.210027
    90  0.544685  0.769115  0.250695  0.285896  0.852395
    

    from sklearn.model_selection import KFold
    
    #added some parameters
    kf = KFold(n_splits = 5, shuffle = True, random_state = 2)
    result = next(kf.split(df), None)
    print (result)
    (array([0, 2, 3, 5, 6, 7, 8, 9]), array([1, 4]))
    
    # result holds row positions, so select with iloc
    train = df.iloc[result[0]]
    test = df.iloc[result[1]]
    
    print (train)
               A         B         C         D         E
    0   0.543405  0.278369  0.424518  0.844776  0.004719
    20  0.891322  0.209202  0.185328  0.108377  0.219697
    30  0.978624  0.811683  0.171941  0.816225  0.274074
    50  0.372832  0.005689  0.252426  0.795663  0.015255
    60  0.598843  0.603805  0.105148  0.381943  0.036476
    70  0.890412  0.980921  0.059942  0.890546  0.576901
    80  0.742480  0.630184  0.581842  0.020439  0.210027
    90  0.544685  0.769115  0.250695  0.285896  0.852395
    
    print (test)
               A         B         C         D         E
    10  0.121569  0.670749  0.825853  0.136707  0.575093
    40  0.431704  0.940030  0.817649  0.336112  0.175410
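
The sample above takes only the first fold. If every fold is needed, the same iloc selection can be applied inside a loop over kf.split(df). The following is a minimal sketch under assumptions: the toy frame and the names folds, train_pos and test_pos are made up for illustration and do not come from the question.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

# Hypothetical toy frame; the non-default index mirrors the sample above.
df = pd.DataFrame(np.arange(20).reshape(10, 2), columns=list("AB"))
df.index = df.index * 10

kf = KFold(n_splits=5, shuffle=True, random_state=2)

# Build one (train, test) pair of DataFrames per fold.
# split() yields integer positions, so iloc is the matching indexer.
folds = [
    (df.iloc[train_pos], df.iloc[test_pos])
    for train_pos, test_pos in kf.split(df)
]

for i, (train, test) in enumerate(folds):
    print(i, train.shape, test.shape)
```

With an index such as 0, 10, 20, ..., df.loc would treat the numbers returned by split() as labels rather than positions, which is why iloc is used here.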