Efficient way to get several subsets of list elements?


I have a DataFrame like this:

import polars as pl

df = pl.DataFrame(
    {
        "grp": ["a", "b"],
        "val": [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
    }
)
df
shape: (2, 2)
┌─────┬──────────────┐
│ grp ┆ val          │
│ --- ┆ ---          │
│ str ┆ list[i64]    │
╞═════╪══════════════╡
│ a   ┆ [1, 2, … 5]  │
│ b   ┆ [1, 2, … 10] │
└─────┴──────────────┘

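I want to select elements in the val column based on this pattern:

  • the first two values,
  • two randomly chosen values from the "middle" set (everything except the first two and last two values),
  • the last two values.

From those three selections, keep unique elements.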

This means that when there are 6 or fewer values (as in the first row) then all values are returned, but otherwise (as in the second row) only a subset of 6 values will be returned.

Therefore, the desired output would look like this:

shape: (2, 2)
┌─────┬─────────────────────┐
│ grp ┆ val                 │
│ --- ┆ ---                 │
│ str ┆ list[i64]           │
╞═════╪═════════════════════╡
│ a   ┆ [1, 2, 3, 4, 5]     │ 
│ b   ┆ [1, 2, 4, 7, 9, 10] │  # <<<< 4 and 7 are the two randomly selected values in the "middle" set
└─────┴─────────────────────┘

To select the first two and last two values, I can use list.head() and list.tail(). For the random pick in the remaining values, I thought I could do list.set_difference() to remove the first and last two values, and then list.sample(). However, list.sample() fails because in the first row, there's only one value left after removing the first and last two, and I ask for two values:

(
    df.select(
        head=pl.col("val").list.head(2),

        middle=pl.col("val")
        .list.set_difference(pl.col("val").list.head(2))
        .list.set_difference(pl.col("val").list.tail(2))
        .list.sample(2, seed=1234),
        
        tail=pl.col("val").list.tail(2),
    ).select(concat=pl.concat_list(["head", "middle", "tail"]).list.unique())
)
ShapeError: cannot take a larger sample than the total population when `with_replacement=false`

and I don't want a sample with replacement.

What would be the best way to do this with Polars?


Solution

  • # increase repr defaults
    pl.Config(fmt_table_cell_list_len=12, fmt_str_lengths=100)
    

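    You could use list.sample(fraction=1, shuffle=True) purely to shuffle the middle values, and then take the first two of the shuffled result with list.head(). Unlike list.sample(2), this never asks for more items than the list contains.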

    df.select(
        pl.col("val").list.head(pl.col("val").list.len() - 2).list.slice(2)
          .list.sample(fraction=1, shuffle=True)
          .list.head(2)
    )
    
    shape: (2, 1)
    ┌───────────┐
    │ val       │
    │ ---       │
    │ list[i64] │
    ╞═══════════╡
    │ [3]       │
    │ [7, 4]    │
    └───────────┘
    

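    when/then can then be used to choose between the original list and the sampled version. Polars evaluates both branches for every row and filters afterwards, so the sampled branch must be valid even for short lists; shuffling and then taking head(2) satisfies that.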

    df.with_columns(
        pl.when(pl.col("val").list.len() > 5)
          .then(
              pl.concat_list(
                  pl.col("val").list.head(2),
                  pl.col("val").list.head(pl.col("val").list.len() - 2).list.slice(2)
                    .list.sample(fraction=1, shuffle=True)
                    .list.head(2),
                  pl.col("val").list.tail(2),
              )
          )
          .otherwise("val")
          .alias("sample")
    )
    
    shape: (2, 3)
    ┌─────┬─────────────────────────────────┬─────────────────────┐
    │ grp ┆ val                             ┆ sample              │
    │ --- ┆ ---                             ┆ ---                 │
    │ str ┆ list[i64]                       ┆ list[i64]           │
    ╞═════╪═════════════════════════════════╪═════════════════════╡
    │ a   ┆ [1, 2, 3, 4, 5]                 ┆ [1, 2, 3, 4, 5]     │
    │ b   ┆ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ┆ [1, 2, 6, 5, 9, 10] │
    └─────┴─────────────────────────────────┴─────────────────────┘
    

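    Alternatively, list.sample() accepts an expression for n, so you can request 0 items for rows where the list is not large enough. That keeps the expression valid for every row, while the outer when/then still returns the original list for those rows.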

    df.with_columns(
        pl.when(pl.col("val").list.len() > 5)
          .then(
              pl.concat_list(
                  pl.col("val").list.head(2),
                  pl.col("val").list.head(pl.col("val").list.len() - 2).list.slice(2)
                    .list.sample(n=pl.when(pl.col("val").list.len() > 5).then(2).otherwise(0)),
                  pl.col("val").list.tail(2)
              )
          )
          .otherwise(pl.col("val"))
          .list.unique()
          .alias("sample")
    )
    
    shape: (2, 3)
    ┌─────┬─────────────────────────────────┬─────────────────────┐
    │ grp ┆ val                             ┆ sample              │
    │ --- ┆ ---                             ┆ ---                 │
    │ str ┆ list[i64]                       ┆ list[i64]           │
    ╞═════╪═════════════════════════════════╪═════════════════════╡
    │ a   ┆ [1, 2, 3, 4, 5]                 ┆ [1, 2, 3, 4, 5]     │
    │ b   ┆ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ┆ [1, 2, 5, 7, 9, 10] │
    └─────┴─────────────────────────────────┴─────────────────────┘