python, dataframe, python-polars

Polars Shuffle And Split DataFrame With Grouping


I am using polars for all preprocessing and feature engineering. I want to shuffle the data before performing a train/valid/test split.

A training 'example' consists of multiple rows, and the number of rows per example varies. Here is a simple contrived example (note: I am actually using a LazyFrame in my code):

import polars as pl

pl.DataFrame({
  "example_id": [1, 1, 2, 2, 2, 3, 3, 3, 4, 4],
  "other_col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
})
┌────────────┬───────────┐
│ example_id ┆ other_col │
│ ---        ┆ ---       │
│ i64        ┆ i64       │
╞════════════╪═══════════╡
│ 1          ┆ 1         │
│ 1          ┆ 2         │
│ 2          ┆ 3         │
│ 2          ┆ 4         │
│ 2          ┆ 5         │
│ 3          ┆ 6         │
│ 3          ┆ 7         │
│ 3          ┆ 8         │
│ 4          ┆ 9         │
│ 4          ┆ 10        │
└────────────┴───────────┘

I want to shuffle 'over' the example_id column, while keeping the examples grouped together. Producing a result something like this:

┌────────────┬───────────┐
│ example_id ┆ other_col │
│ ---        ┆ ---       │
│ i64        ┆ i64       │
╞════════════╪═══════════╡
│ 2          ┆ 3         │
│ 2          ┆ 4         │
│ 2          ┆ 5         │
│ 1          ┆ 1         │
│ 1          ┆ 2         │
│ 4          ┆ 9         │
│ 4          ┆ 10        │
│ 3          ┆ 6         │
│ 3          ┆ 7         │
│ 3          ┆ 8         │
└────────────┴───────────┘

I then want to split the data fractionally, for example 0.6, 0.2, 0.2 for training, validation and testing respectively, but do this based on 'whole examples' rather than just row wise.

Is there a clean way to do this in polars without having to convert the example_id to an array, shuffling it, splitting into sublists, then reselecting from the original dataframe?


Solution

  • There may well be a cleaner way of achieving this, and hopefully someone can improve on it. It also requires collecting the dataframe, which is not ideal. Either way, it seems to work for now. Thanks @jqurious for the pointer.

    1. Grab the unique example_ids, shuffle them, and add a row index.
    example_ids = (
      example_df
      .select("example_id")
      .unique()
      .sample(fraction=1, shuffle=True)
      .with_row_index()
    )
    
    2. Split the unique ids into subsets using the row index.
    # assume we'll test on remaining data
    train_frac = 0.6
    valid_frac = 0.2
    
    train_ids = example_ids.filter(
      pl.col("index") < pl.col("index").max() * train_frac
    )
    
    valid_ids = example_ids.filter(
        pl.col("index").is_between(
            pl.col("index").max() * train_frac,
            pl.col("index").max() * (train_frac + valid_frac),
        )
    )
    test_ids = example_ids.filter(
        pl.col("index") > pl.col("index").max() * (train_frac + valid_frac)
    )
    
    3. Join each subset back to the example_df and drop the index column.
    train_df = example_df.join(train_ids, on="example_id").drop("index")
    valid_df = example_df.join(valid_ids, on="example_id").drop("index")
    test_df = example_df.join(test_ids, on="example_id").drop("index")
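
    For reference, here are the three steps combined into one self-contained script (the seed is my addition, just so the shuffle is reproducible; the fractions are the same as above):

    ```python
    import polars as pl

    example_df = pl.DataFrame({
        "example_id": [1, 1, 2, 2, 2, 3, 3, 3, 4, 4],
        "other_col": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    })

    # 1. Unique ids, shuffled, with a row index
    example_ids = (
        example_df
        .select("example_id")
        .unique()
        .sample(fraction=1.0, shuffle=True, seed=42)
        .with_row_index()
    )

    # 2. Split the ids by index; test gets the remainder
    train_frac, valid_frac = 0.6, 0.2
    max_idx = pl.col("index").max()

    train_ids = example_ids.filter(pl.col("index") < max_idx * train_frac)
    valid_ids = example_ids.filter(
        pl.col("index").is_between(
            max_idx * train_frac,
            max_idx * (train_frac + valid_frac),
        )
    )
    test_ids = example_ids.filter(pl.col("index") > max_idx * (train_frac + valid_frac))

    # 3. Join each subset back and drop the index column
    train_df = example_df.join(train_ids, on="example_id").drop("index")
    valid_df = example_df.join(valid_ids, on="example_id").drop("index")
    test_df = example_df.join(test_ids, on="example_id").drop("index")

    print(train_df, valid_df, test_df, sep="\n")
    ```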
    

    This will produce three dataframes, something like this:

    ┌────────────┬───────────┐
    │ example_id ┆ other_col │
    │ ---        ┆ ---       │
    │ i64        ┆ i64       │
    ╞════════════╪═══════════╡
    │ 1          ┆ 1         │
    │ 1          ┆ 2         │
    │ 3          ┆ 6         │
    │ 3          ┆ 7         │
    │ 3          ┆ 8         │
    └────────────┴───────────┘
    
    ┌────────────┬───────────┐
    │ example_id ┆ other_col │
    │ ---        ┆ ---       │
    │ i64        ┆ i64       │
    ╞════════════╪═══════════╡
    │ 2          ┆ 3         │
    │ 2          ┆ 4         │
    │ 2          ┆ 5         │
    └────────────┴───────────┘
    
    ┌────────────┬───────────┐
    │ example_id ┆ other_col │
    │ ---        ┆ ---       │
    │ i64        ┆ i64       │
    ╞════════════╪═══════════╡
    │ 4          ┆ 9         │
    │ 4          ┆ 10        │
    └────────────┴───────────┘