
Stratified GroupShuffleSplit in Scikit-learn


I would like to ask whether it is possible to do a "Stratified GroupShuffleSplit" in scikit-learn, in other words a combination of GroupShuffleSplit and StratifiedShuffleSplit.

Here is a sample of the code I am using:

from sklearn.model_selection import GroupShuffleSplit, GridSearchCV
from sklearn.svm import SVC

cv = GroupShuffleSplit(n_splits=n_splits, test_size=test_size,
                       train_size=train_size, random_state=random_state).split(
    allr_sets_nor[:, :2], allr_labels, groups=allr_groups)
opt = GridSearchCV(SVC(decision_function_shape=dfs, tol=tol), param_grid=param_grid,
                   scoring=scoring, n_jobs=n_jobs, cv=cv, verbose=verbose)
opt.fit(allr_sets_nor[:, :2], allr_labels)

Here I applied GroupShuffleSplit, but I still want to add stratification according to allr_labels.


Solution

  • I solved the problem by applying StratifiedShuffleSplit to the groups only (one entry and one label per group) and then building the training and testing sample indices manually, since they are linked to the group indices: in my case each group contains 6 consecutive sets, from 6*index to 6*index+5. This is done as follows:

    import numpy as np
    from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV
    from sklearn.svm import SVC

    # stratified splitting of the groups only: all_groups has one entry per group
    # and all_labels holds the label of each group
    sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size,
                                 train_size=train_size, random_state=random_state)
    offsets = np.arange(6)[:, None]  # group g owns samples 6*g .. 6*g+5

    cv = []
    for train_index, test_index in sss.split(all_groups, all_labels):
        # expand each selected group index into the indices of its 6 sets
        train_is = (train_index * 6 + offsets).ravel()
        test_is = (test_index * 6 + offsets).ravel()
        cv.append((train_is, test_is))
    # cv is now the final cross-validation iterable: a list of n_splits tuples,
    # each holding the training and testing index arrays respectively

    opt = GridSearchCV(SVC(decision_function_shape=dfs, tol=tol), param_grid=param_grid,
                       scoring=scoring, n_jobs=n_jobs, cv=cv, verbose=verbose)
    opt.fit(allr_sets_nor[:, :2], allr_labels)
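The index arithmetic above relies on every group being exactly 6 consecutive sets. A more general sketch, assuming only that all samples in a group share one label, can look up each group's label from its first sample and map the chosen groups back to sample indices with np.isin. The function name stratified_group_shuffle_split and the synthetic data are hypothetical, not part of scikit-learn.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


def stratified_group_shuffle_split(groups, labels, n_splits=5, test_size=0.2,
                                   random_state=None):
    """Hypothetical helper: StratifiedShuffleSplit applied at the group level.

    Assumes every sample of a group has the same label. Yields
    (train_idx, test_idx) arrays of sample indices.
    """
    groups = np.asarray(groups)
    labels = np.asarray(labels)
    # one entry per distinct group, labelled by the group's first sample
    unique_groups, first_idx = np.unique(groups, return_index=True)
    group_labels = labels[first_idx]
    sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size,
                                 random_state=random_state)
    for train_g, test_g in sss.split(unique_groups, group_labels):
        # map the selected groups back to all samples they contain
        train_idx = np.flatnonzero(np.isin(groups, unique_groups[train_g]))
        test_idx = np.flatnonzero(np.isin(groups, unique_groups[test_g]))
        yield train_idx, test_idx


# Usage on synthetic data (assumption): 30 groups of 3 to 8 samples, 3 classes
rng = np.random.RandomState(0)
sizes = rng.randint(3, 9, size=30)
groups = np.repeat(np.arange(30), sizes)
labels = np.repeat(np.arange(30) % 3, sizes)

splits = list(stratified_group_shuffle_split(groups, labels, n_splits=3,
                                             test_size=0.2, random_state=0))
for train_idx, test_idx in splits:
    assert not set(groups[train_idx]) & set(groups[test_idx])  # no group leakage
```

The resulting list can be passed directly as cv to GridSearchCV, as in the solution above. Note that test_size counts groups, not samples, so when group sizes differ the sample-level test fraction will only be approximately test_size.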