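I would like to return the rows of a dataset whose groups satisfy two conditions. That is, each group (by item and location) must:
1. contain more than one unique store, and
2. contain "store 3".
I am getting hung up on the latter!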
Suppose I have this dataframe:
import polars as pl
df = pl.from_repr("""
┌──────┬──────────┬─────────┐
│ item ┆ location ┆ store │
│ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ str │
╞══════╪══════════╪═════════╡
│ 0 ┆ new york ┆ store 1 │
│ 0 ┆ boston ┆ store 1 │
│ 1 ┆ boston ┆ store 1 │
│ 1 ┆ boston ┆ store 2 │
│ 0 ┆ ohio ┆ store 1 │
│ 0 ┆ ohio ┆ store 3 │
└──────┴──────────┴─────────┘
""")
I can check for #1 easily using filter, n_unique(), and over as follows:
df.filter(
    (pl.col("store").n_unique() > 1).over("item", "location")
)
┌──────┬──────────┬─────────┐
│ item ┆ location ┆ store │
│ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ str │
╞══════╪══════════╪═════════╡
│ 1 ┆ boston ┆ store 1 │
│ 1 ┆ boston ┆ store 2 │
│ 0 ┆ ohio ┆ store 1 │
│ 0 ┆ ohio ┆ store 3 │
└──────┴──────────┴─────────┘
However, I would also like the filter to keep only rows whose group contains "store 3".
I tried this (and variations of it):
df.filter(
    (pl.col("store").n_unique() > 1).over("item", "location")
    & (pl.col("store").is_in(["store 3"]))
)
┌──────┬──────────┬─────────┐
│ item ┆ location ┆ store │
│ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ str │
╞══════╪══════════╪═════════╡
│ 0 ┆ ohio ┆ store 3 │
└──────┴──────────┴─────────┘
But this returns only a single row. The result makes some sense to me, since the second condition is checked row by row, but I was hoping to return the whole group. For example:
┌──────┬──────────┬─────────┐
│ item ┆ location ┆ store │
│ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ str │
╞══════╪══════════╪═════════╡
│ 0 ┆ ohio ┆ store 1 │
│ 0 ┆ ohio ┆ store 3 │
└──────┴──────────┴─────────┘
I'm quite sure I'm missing something obvious. Any help is appreciated. Loving Polars so far without a doubt!
(pl.col(...) == value).any().over(group) can be used to test if a value exists within a group.
You can combine multiple conditions with & and run them through a single .over():
(one & two).over(group)
df.filter(
    (
        (pl.col("store").n_unique() > 1) &
        (pl.col('store') == 'store 3').any()
    )
    .over('item', 'location')
)
pl.all_horizontal() is another option:
df.filter(
    pl.all_horizontal(
        pl.col("store").n_unique() > 1,
        (pl.col('store') == 'store 3').any()
    )
    .over('item', 'location')
)
shape: (2, 3)
┌──────┬──────────┬─────────┐
│ item ┆ location ┆ store │
│ --- ┆ --- ┆ --- │
│ i64 ┆ str ┆ str │
╞══════╪══════════╪═════════╡
│ 0 ┆ ohio ┆ store 1 │
│ 0 ┆ ohio ┆ store 3 │
└──────┴──────────┴─────────┘