python, scikit-learn

Stratification fails in train_test_split


Please consider the following code:


import pandas as pd
from sklearn.model_selection import train_test_split


# step 1
ids = list(range(1000))
label = 500 * [1.0] + 500 * [0.0]
df = pd.DataFrame({"id": ids, "label": label})

# step 2
train_p = 0.8
val_p = 0.1
test_p = 0.1

# step 3
n_train = int(len(df) * train_p)
n_val = int(len(df) * val_p)
n_test = len(df) - n_train - n_val

print("* Step 3")
print("train:", n_train)
print("val:", n_val)
print("test:", n_test)
print()

# step 4
train_ids, test_ids = train_test_split(df["id"], stratify=df.label, test_size=n_test, random_state=42)

# step 5
print("* Step 5. First split")
print( df.loc[df.id.isin(train_ids), "label"].value_counts() )
print( df.loc[df.id.isin(test_ids), "label"].value_counts() )
print()

# step 6
train_ids, val_ids = train_test_split(train_ids, stratify=df.loc[df.id.isin(train_ids), "label"], test_size=n_val, random_state=42)

# step 7
train_df = df[df["id"].isin(train_ids)]
val_df = df[df["id"].isin(val_ids)]
test_df = df[df["id"].isin(test_ids)]

# step 8
print("* Step 8. Final split")
print("train:", train_df["label"].value_counts())
print("val:", val_df["label"].value_counts())
print("test:", test_df["label"].value_counts())


with output:

* Step 3
train: 800
val: 100
test: 100

* Step 5. First split
label
1.0    450
0.0    450
Name: count, dtype: int64
label
1.0    50
0.0    50
Name: count, dtype: int64

* Step 8. Final split
train: label
0.0    404
1.0    396
Name: count, dtype: int64
val: label
1.0    54
0.0    46
Name: count, dtype: int64
test: label
1.0    50
0.0    50
Name: count, dtype: int64

The code does the following:

  1. Create a DataFrame with 1000 elements, perfectly balanced between class 1 and class 0 (positive and negative).
  2. Define the fraction of examples that should go into the training, validation and test partitions. I would like 800 examples in the training split and 100 in each of the other two.
  3. Compute the sizes of the three partitions and print them.
  4. Perform the first split to obtain the test set, stratified on label.
  5. Print the label stats of the first split. The two partitions are still balanced.
  6. Perform the second split into training and validation, stratified on label.
  7. Select the examples of each partition.
  8. Print the label stats of the final partitions.

As you can see, the second split at step 6 does not produce a balanced split (stats printed at step 8). After the first split the two partitions are still balanced (output at step 5), so it should be possible to perform the second split while keeping a perfect class balance.

What am I doing wrong?


Solution

  • You're only passing the IDs to the second split, and the stratify labels you pass are no longer aligned with them. train_test_split pairs the array being split and the stratify array by position, not by pandas index: train_ids comes back from the first split in shuffled order, while df.loc[df.id.isin(train_ids), "label"] returns the labels in the original DataFrame order, so each ID is stratified against the wrong label, as the quick check below shows.
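
    A quick check (a minimal sketch, reusing df and the train_ids from step 4 of the question) makes the misalignment visible: the IDs come back shuffled, while the labels are listed in the original row order.

    print(train_ids.head().tolist())                               # shuffled ids returned by the first split
    print(df.loc[df.id.isin(train_ids), "label"].head().tolist())  # labels in original DataFrame order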

    In your first split, df["id"] and df.label both come from the same DataFrame, so their rows line up:

    # step 4
    train_ids, test_ids = train_test_split(df["id"], stratify=df.label, test_size=n_test, random_state=42)
    

    So, in your second split, do the same: first select the training rows, then take both the IDs and the labels from that same frame:

    train_df = df[df.id.isin(train_ids)]
    train_ids, val_ids = train_test_split(train_df["id"], stratify=train_df["label"], test_size=n_val, random_state=42)
    

    and it will work as intended.
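
    Putting it together (a minimal end-to-end sketch of the corrected splitting, using the same data setup as the question), each partition should now show a 50/50 label distribution:

    import pandas as pd
    from sklearn.model_selection import train_test_split

    df = pd.DataFrame({"id": range(1000), "label": 500 * [1.0] + 500 * [0.0]})
    n_val, n_test = 100, 100

    # first split: ids and labels both come from df, so they are aligned row by row
    train_ids, test_ids = train_test_split(df["id"], stratify=df["label"], test_size=n_test, random_state=42)

    # second split: select the training rows first, then take ids and labels from that same frame
    train_df = df[df["id"].isin(train_ids)]
    train_ids, val_ids = train_test_split(train_df["id"], stratify=train_df["label"], test_size=n_val, random_state=42)

    # each partition should now be balanced (e.g. {1.0: 400, 0.0: 400} for train)
    for name, ids in [("train", train_ids), ("val", val_ids), ("test", test_ids)]:
        print(name, df.loc[df["id"].isin(ids), "label"].value_counts().to_dict())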