
Exact stratification of a pandas DataFrame using multiple columns


I have found a way to stratify a pandas DataFrame over multiple columns. However, the resulting stratification is not exact: some label combinations are more represented in certain folds than in others. I would like a way to stratify the DataFrame more exactly. Every fold should have the same size, and every combination of labels across the stratification columns should appear in equal numbers (up to a difference of 1) in each fold.

I am using:

scikit-learn==1.4.1.post1 
scikit-multilearn==0.2.0 
numpy==1.26.4 
pandas==2.2.1

Inspired by the answer from grofte in sklearn-train-test-split-on-pandas-stratify-by-multiple-columns, I first generated a random DataFrame with 300 rows and the 3 columns I want to stratify on. I then wrote a stratification function (iterative_split) that uses IterativeStratification from skmultilearn. I use this function to split the DataFrame into 10 folds, hoping to get an equal share of each label combination in every fold:

import pandas as pd
import numpy as np
import random
from skmultilearn.model_selection import IterativeStratification

random.seed(1)
lcol1 = ['a', 'b', 'c']
lcol2 = ['e', 'f']
lcol3 = ['g', 'h']
example = pd.DataFrame({
    'col1': random.choices(lcol1, k=300, weights=[1,1,3]),
    'col2': random.choices(lcol2, k=300),
    'col3': random.choices(lcol3, k=300)})
example.head()

def iterative_split(df, fold_distribution, n_splits, stratify_columns):
    """Custom iterative train test split which
    'maintains balanced representation with respect
    to order-th label combinations.'

    From https://madewithml.com/courses/mlops/splitting/#stratified-split
    """
    # One-hot encode the stratify columns and concatenate them
    one_hot_cols = [pd.get_dummies(df[col]) for col in stratify_columns]
    one_hot_cols = pd.concat(one_hot_cols, axis=1).to_numpy()
    stratifier = IterativeStratification(
        n_splits=n_splits, order=len(stratify_columns),
        sample_distribution_per_fold=fold_distribution)
    folds = []
    # split() yields (train_indices, test_indices); each test set is one fold
    for train_indices, test_indices in stratifier.split(df.to_numpy(), one_hot_cols):
        folds.append(df.iloc[test_indices])
    return folds

n_splits = 10
folds = iterative_split(example, [1./n_splits,]*n_splits, n_splits, ['col1', 'col2', 'col3'])
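To make the input of IterativeStratification concrete, the sketch below rebuilds the one-hot label matrix used inside iterative_split on a hypothetical three-row DataFrame. Each original column contributes one block of indicator columns (3 + 2 + 2 = 7 here), and each row has exactly one indicator set per block, so it carries three labels. The `.astype(int)` cast is added only for readability, since `get_dummies` returns boolean columns in pandas 2.x.

```python
import pandas as pd

# Hypothetical three-row stand-in for `example`
df = pd.DataFrame({'col1': ['a', 'c', 'b'],
                   'col2': ['e', 'e', 'f'],
                   'col3': ['h', 'g', 'h']})

# Same construction as in iterative_split: one dummy block per column
blocks = [pd.get_dummies(df[col]) for col in ['col1', 'col2', 'col3']]
labels = pd.concat(blocks, axis=1).astype(int)

print(labels)                     # columns a, b, c, e, f, g, h
print(labels.to_numpy().shape)    # (3, 7)
```

Because the stratifier only sees these 7 binary columns, the exact 3-way combinations (for example 'c', 'f', 'h') are not labels in their own right.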

The only output of the previous code (produced by example.head()) shows the first rows of the generated DataFrame:


    col1    col2    col3
0   a   e   h
1   c   e   h
2   c   e   g
3   b   e   h
4   c   f   h

To visualize the split results, I am using a helper print function:

def print_results(parent, folds):

    # For each combination of labels, print the number of rows in each fold
    # that have this combination
    total_errors = 0
    n_splits = len(folds)
    print("query   : #rows      : #rows per fold\n")
    for col1 in parent.col1.unique():
        for col2 in parent.col2.unique():
            for col3 in parent.col3.unique():
                df = parent.query("col1==@col1 and col2==@col2 and col3==@col3")
                len_query = len(df)
                print(f"{col1}, {col2}, {col3} : total = {len_query} : per fold =", end = ' ')
                for fold in folds:
                    df0 = fold.query("col1==@col1 and col2==@col2 and col3==@col3")
                    len_query_fold = len(df0)
                    if abs(len_query_fold-len_query/n_splits) >= 1:
                        total_errors += 1
                    print(f"{len_query_fold} -", end= ' ')
                print("")

    # Prints the statistics and the number of stratification errors
    expected_total_length = len(parent)
    total_length = 0
    total_mismatches = 0
    print("\nlengths of folds : ", end = ' ')
    for fold in folds:
        len_fold = len(fold)
        print(len_fold, end=' ')
        total_length += len_fold
        if abs(len_fold-expected_total_length/n_splits) >= 1:
            total_mismatches += 1
    print(f"\nExpected total_length = {expected_total_length}")
    print(f"Effective total_length = {total_length}")

    print(f"total number of stratification errors: {total_errors}")
    print(f"total number of mismatched fold sizes : {total_mismatches}")

print_results(example, folds)

I obtain the following output:

query   : #rows      : #rows per fold

a, e, h : total = 12 : per fold = 1 - 1 - 1 - 1 - 2 - 1 - 1 - 2 - 1 - 1 - 
a, e, g : total = 15 : per fold = 2 - 2 - 2 - 2 - 1 - 1 - 1 - 1 - 1 - 2 - 
a, f, h : total = 18 : per fold = 1 - 2 - 1 - 2 - 1 - 2 - 3 - 2 - 3 - 1 - 
a, f, g : total = 15 : per fold = 2 - 1 - 2 - 1 - 2 - 2 - 1 - 1 - 1 - 2 - 
c, e, h : total = 46 : per fold = 4 - 4 - 5 - 5 - 4 - 5 - 5 - 5 - 4 - 5 - 
c, e, g : total = 33 : per fold = 3 - 3 - 3 - 3 - 4 - 4 - 4 - 3 - 3 - 3 - 
c, f, h : total = 52 : per fold = 8 - 7 - 5 - 5 - 6 - 5 - 4 - 3 - 4 - 5 - 
c, f, g : total = 45 : per fold = 5 - 5 - 5 - 5 - 4 - 3 - 4 - 5 - 5 - 4 - 
b, e, h : total = 17 : per fold = 1 - 1 - 2 - 1 - 2 - 2 - 2 - 2 - 2 - 2 - 
b, e, g : total = 15 : per fold = 3 - 2 - 1 - 2 - 1 - 1 - 1 - 1 - 2 - 1 - 
b, f, h : total = 21 : per fold = 2 - 2 - 2 - 3 - 2 - 2 - 2 - 2 - 2 - 2 - 
b, f, g : total = 11 : per fold = 2 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 

lengths of folds :  34 31 30 31 30 29 29 28 29 29 
Expected total_length = 300
Effective total_length = 300
total number of stratification errors: 9
total number of mismatched fold sizes : 8

We see here the two problems with this stratification:

The first 12 output lines give the number of DataFrame rows with each combination of labels. The first line corresponds to 'a' in the first column, 'e' in the second column and 'h' in the third column, and there are 12 DataFrame rows in total with this combination. With ten folds, an exact stratification would assign either 1 or 2 of these rows (called 'subjects' from here on) to each fold, which is indeed the case for the first output line.

This leads to the first problem. Take the 7th line (combination 'c', 'f', 'h'): there are 52 subjects in total, so we would expect either 5 or 6 subjects with this combination in each fold. However, the first fold has 8 subjects, the second has 7, and the 7th, 8th and 9th folds have only 3 or 4. These deviations are counted as stratification errors; there are 9 of them across all folds and combinations.
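The 'either 1 or 2' and 'either 5 or 6' figures come from integer division. With `total` subjects of a combination and `n_splits` folds, an exact split gives every fold `total // n_splits` subjects and gives one extra subject to `total % n_splits` of the folds. A minimal sketch; the helper name expected_fold_counts is hypothetical, and which folds receive the extra subjects is arbitrary (here, the first ones):

```python
def expected_fold_counts(total, n_splits):
    """Per-fold counts of an exact split (hypothetical helper)."""
    base, extra = divmod(total, n_splits)
    # `extra` folds get one more subject than the others
    return [base + 1] * extra + [base] * (n_splits - extra)


print(expected_fold_counts(12, 10))  # [2, 2, 1, 1, 1, 1, 1, 1, 1, 1]
print(expected_fold_counts(52, 10))  # [6, 6, 5, 5, 5, 5, 5, 5, 5, 5]
```

This is consistent with the rule in print_results. A fold count is flagged when it differs from total / n_splits by 1 or more, which happens exactly when it falls outside the values returned here.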

The second problem is fold size. An exact stratification would give each fold the same number of subjects whenever possible (here, 300/10 = 30). Instead, the fold lengths range from 28 to 34. We call these 'mismatched fold sizes', and 8 of the 10 folds are mismatched.
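Both kinds of error can also be counted without the triple loop and the query strings. The sketch below is a hypothetical helper, not part of print_results. It builds a combination-by-fold count table with groupby and applies the same 'deviation of 1 or more' rule. It assumes `columns` is a list and that the folds partition `parent`.

```python
import pandas as pd


def count_split_errors(parent, folds, columns):
    """Return (stratification errors, mismatched fold sizes).

    Hypothetical helper using the same 'deviation >= 1' rule as print_results.
    """
    n_splits = len(folds)
    # Tag every row with the index of the fold it belongs to
    stacked = pd.concat([f[columns].assign(_fold=i) for i, f in enumerate(folds)])
    # Total number of rows per label combination in the parent DataFrame
    totals = parent.groupby(columns).size()
    # Combination-by-fold table of row counts (0 where a fold lacks a combination)
    per_fold = (stacked.groupby(columns + ['_fold']).size()
                .unstack('_fold', fill_value=0)
                .reindex(index=totals.index, columns=range(n_splits), fill_value=0))
    ideal = totals / n_splits
    strat_errors = int((per_fold.sub(ideal, axis=0).abs() >= 1).to_numpy().sum())
    ideal_size = len(parent) / n_splits
    size_errors = sum(abs(len(f) - ideal_size) >= 1 for f in folds)
    return strat_errors, size_errors


# Hypothetical toy data: 12 rows of ('p', 'u') and 8 rows of ('q', 'u')
parent = pd.DataFrame({'x': ['p'] * 12 + ['q'] * 8, 'y': ['u'] * 20})
print(count_split_errors(parent, [parent.iloc[::2], parent.iloc[1::2]], ['x', 'y']))  # (0, 0)
print(count_split_errors(parent, [parent.iloc[:12], parent.iloc[12:]], ['x', 'y']))   # (4, 2)
```

For the folds above, `count_split_errors(example, folds, ['col1', 'col2', 'col3'])` should return the same two totals that print_results prints.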

Do you know of another way to exactly stratify a DataFrame according to the labels of several columns that would reduce the number of stratification errors?


Solution

  • A solution that seems to work is to shuffle the DataFrame rows, sort the DataFrame by the columns we want to stratify on, and then give each fold one row out of every n_splits rows. I first create the DataFrame:

    import random
    import pandas as pd
    
    random.seed(1)
    lcol1 = ['a', 'b', 'c']
    lcol2 = ['e', 'f']
    lcol3 = ['g', 'h']
    example = pd.DataFrame({
        'col1': random.choices(lcol1, k=300, weights=[1,1,3]),
        'col2': random.choices(lcol2, k=300),
        'col3': random.choices(lcol3, k=300)})
    

    Here is the function that creates 10 folds stratified for each combination of column labels:

    def iterative_split_through_sorting(df, n_splits, stratify_columns, random_state):
        """Split df into n_splits folds that are stratified over every
        combination of labels in stratify_columns.
        """
    
        # Shuffle the rows, then sort them by stratify_columns so that rows
        # sharing a label combination become contiguous (a stable sort keeps
        # the shuffled order inside each combination)
        sorted_df = df.sample(frac=1, random_state=random_state).sort_values(
            stratify_columns, kind="stable")
        # Fold i takes every n_splits-th row, starting at row i
        folds = [sorted_df.iloc[i::n_splits, :] for i in range(n_splits)]
    
        return folds
    

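    We compute the folds with this function and pass them to the helper print function. The random_state value used in the original is not shown. The value 1 below is a placeholder: it changes which rows land in each fold, but not the per-fold counts.

    folds = iterative_split_through_sorting(example, 10, ['col1', 'col2', 'col3'], random_state=1)
    print_results(example, folds)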
    

    we obtain as output:

    query   : #rows      : #rows per fold
    
    a, e, h : total = 12 : per fold = 1 - 1 - 1 - 1 - 1 - 2 - 2 - 1 - 1 - 1 - 
    a, e, g : total = 15 : per fold = 2 - 2 - 2 - 2 - 2 - 1 - 1 - 1 - 1 - 1 - 
    a, f, h : total = 18 : per fold = 1 - 1 - 2 - 2 - 2 - 2 - 2 - 2 - 2 - 2 - 
    a, f, g : total = 15 : per fold = 2 - 2 - 1 - 1 - 1 - 1 - 1 - 2 - 2 - 2 - 
    c, e, h : total = 46 : per fold = 5 - 5 - 5 - 4 - 4 - 4 - 4 - 5 - 5 - 5 - 
    c, e, g : total = 33 : per fold = 3 - 3 - 3 - 3 - 4 - 4 - 4 - 3 - 3 - 3 - 
    c, f, h : total = 52 : per fold = 5 - 5 - 5 - 5 - 5 - 5 - 5 - 5 - 6 - 6 - 
    c, f, g : total = 45 : per fold = 4 - 4 - 4 - 5 - 5 - 5 - 5 - 5 - 4 - 4 - 
    b, e, h : total = 17 : per fold = 2 - 2 - 1 - 1 - 1 - 2 - 2 - 2 - 2 - 2 - 
    b, e, g : total = 15 : per fold = 2 - 2 - 2 - 2 - 2 - 1 - 1 - 1 - 1 - 1 - 
    b, f, h : total = 21 : per fold = 2 - 2 - 2 - 3 - 2 - 2 - 2 - 2 - 2 - 2 - 
    b, f, g : total = 11 : per fold = 1 - 1 - 2 - 1 - 1 - 1 - 1 - 1 - 1 - 1 - 
    
    lengths of folds :  30 30 30 30 30 30 30 30 30 30 
    Expected total_length = 300
    Effective total_length = 300
    total number of stratification errors: 0
    total number of mismatched fold sizes : 0
    

    This time, there are no stratification errors and no mismatched fold sizes. Running the following shows a snippet of one fold:

    folds[0].head()
    

    which outputs:

    col1    col2    col3
    154 a   e   g
    234 a   e   g
    196 a   e   h
    100 a   f   g
    181 a   f   g
    

    By construction, each fold inherits the sort order, so rows with the same label combination (for example 'a', 'e', 'g') sit next to each other within every fold. We can improve the function by also shuffling the rows inside each fold (note that the order of the folds is shuffled as well):

    def iterative_split_through_sorting_shuffle(df, n_splits, stratify_columns, random_state):
        """Like iterative_split_through_sorting, but also shuffles the rows
        inside each fold and the order of the folds.
        """
    
        # Shuffle the rows, then (stably) sort them by stratify_columns
        sorted_df = df.sample(frac=1, random_state=random_state).sort_values(
            stratify_columns, kind="stable")
        # Fold i takes every n_splits-th row, starting at row i
        folds = [sorted_df.iloc[i::n_splits, :] for i in range(n_splits)]
        # Further shuffling: rows within each fold, then the fold order
        # (uses the `random` module imported earlier)
        folds = [fold.sample(frac=1, random_state=random_state) for fold in folds]
        random.Random(random_state).shuffle(folds)
    
        return folds
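The same allocation can also be expressed as one fold id per row. This leaves the DataFrame in its original order, which avoids the grouping effect noted above without a second shuffle. The sketch below is a variant, not the answer's function. `assign_folds` is a hypothetical name, and the sketch assumes a unique index and no missing values in the stratify columns. For each row it computes the position the row would have after shuffling and stably sorting by the stratify columns (the offset of its combination plus its rank inside that combination). It then takes that position modulo n_splits, just like `iloc[i::n_splits]`.

```python
import numpy as np
import pandas as pd


def assign_folds(df, n_splits, stratify_columns, random_state=None):
    """Return a Series of fold ids (0 .. n_splits-1) aligned with df.index.

    Hypothetical helper: assumes a unique index and no NaN in stratify_columns.
    """
    # Same first step as the answer: shuffle the rows
    shuffled = df.sample(frac=1, random_state=random_state)
    grouped = shuffled.groupby(stratify_columns, sort=True)
    group_id = grouped.ngroup().to_numpy()  # combination number, in sorted key order
    rank = grouped.cumcount().to_numpy()    # 0-based rank inside the combination
    sizes = grouped.size().to_numpy()       # rows per combination, same order
    offsets = np.cumsum(sizes) - sizes      # first sorted position of each combination
    # Position the row would have after a stable sort by stratify_columns
    position = offsets[group_id] + rank
    fold_ids = pd.Series(position % n_splits, index=shuffled.index)
    # Return the ids in the original row order
    return fold_ids.reindex(df.index)


# Usage on a small hypothetical DataFrame (52 rows, 5 folds)
df = pd.DataFrame({'col1': ['a'] * 12 + ['b'] * 7 + ['c'] * 33,
                   'col2': ['e', 'f'] * 26})
fold_ids = assign_folds(df, 5, ['col1', 'col2'], random_state=0)
folds = [df[fold_ids == k] for k in range(5)]
print([len(f) for f in folds])  # [11, 11, 10, 10, 10]
```

Since positions depend only on the combination sizes, the per-fold counts are the same for every seed. The seed only decides which rows of a combination go to which fold.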