
Find intersection of dates in grouped polars dataframe


Consider the following pl.DataFrame:

import polars as pl

data = {
    "symbol": ["AAPL"] * 5 + ["GOOGL"] * 3 + ["MSFT"] * 4,
    "date": [
        "2023-01-01",
        "2023-01-02",
        "2023-01-03",
        "2023-01-04",
        "2023-01-05",  # AAPL has 5 dates
        "2023-01-01",
        "2023-01-02",
        "2023-01-03",  # GOOGL has 3 dates
        "2023-01-01",
        "2023-01-02",
        "2023-01-03",
        "2023-01-04",  # MSFT has 4 dates
    ],
}

df = pl.DataFrame(data)

with pl.Config(tbl_rows=-1):
    print(df)

shape: (12, 2)
┌────────┬────────────┐
│ symbol ┆ date       │
│ ---    ┆ ---        │
│ str    ┆ str        │
╞════════╪════════════╡
│ AAPL   ┆ 2023-01-01 │
│ AAPL   ┆ 2023-01-02 │
│ AAPL   ┆ 2023-01-03 │
│ AAPL   ┆ 2023-01-04 │
│ AAPL   ┆ 2023-01-05 │
│ GOOGL  ┆ 2023-01-01 │
│ GOOGL  ┆ 2023-01-02 │
│ GOOGL  ┆ 2023-01-03 │
│ MSFT   ┆ 2023-01-01 │
│ MSFT   ┆ 2023-01-02 │
│ MSFT   ┆ 2023-01-03 │
│ MSFT   ┆ 2023-01-04 │
└────────┴────────────┘

I need to make the dates of each group (grouped by symbol) consistent across all groups. To do so, I need to identify the dates common to all groups (probably using a join) and then filter the dataframe accordingly.

It might be related to the question "Find intersection of columns from different polars dataframes".

I am looking for a generalized solution. In the above example the resulting pl.DataFrame should look as follows:

shape: (9, 2)
┌────────┬────────────┐
│ symbol ┆ date       │
│ ---    ┆ ---        │
│ str    ┆ str        │
╞════════╪════════════╡
│ AAPL   ┆ 2023-01-01 │
│ AAPL   ┆ 2023-01-02 │
│ AAPL   ┆ 2023-01-03 │
│ GOOGL  ┆ 2023-01-01 │
│ GOOGL  ┆ 2023-01-02 │
│ GOOGL  ┆ 2023-01-03 │
│ MSFT   ┆ 2023-01-01 │
│ MSFT   ┆ 2023-01-02 │
│ MSFT   ┆ 2023-01-03 │
└────────┴────────────┘

Solution

  • You could count the number of unique symbols per date (n_unique over date) and keep only the rows whose date is shared by every symbol:

    df.filter(pl.col('symbol').n_unique().over('date')
                .eq(pl.col('symbol').n_unique()))
    

    Output:

    ┌────────┬────────────┐
    │ symbol ┆ date       │
    │ ---    ┆ ---        │
    │ str    ┆ str        │
    ╞════════╪════════════╡
    │ AAPL   ┆ 2023-01-01 │
    │ AAPL   ┆ 2023-01-02 │
    │ AAPL   ┆ 2023-01-03 │
    │ GOOGL  ┆ 2023-01-01 │
    │ GOOGL  ┆ 2023-01-02 │
    │ GOOGL  ┆ 2023-01-03 │
    │ MSFT   ┆ 2023-01-01 │
    │ MSFT   ┆ 2023-01-02 │
    │ MSFT   ┆ 2023-01-03 │
    └────────┴────────────┘
    

    Intermediates:

    ┌────────┬────────────┬───────────────────┬───────────────────┐
    │ symbol ┆ date       ┆ nunique_over_date ┆ eq_symbol_nunique │
    │ ---    ┆ ---        ┆ ---               ┆ ---               │
    │ str    ┆ str        ┆ u32               ┆ bool              │
    ╞════════╪════════════╪═══════════════════╪═══════════════════╡
    │ AAPL   ┆ 2023-01-01 ┆ 3                 ┆ true              │
    │ AAPL   ┆ 2023-01-02 ┆ 3                 ┆ true              │
    │ AAPL   ┆ 2023-01-03 ┆ 3                 ┆ true              │
    │ AAPL   ┆ 2023-01-04 ┆ 2                 ┆ false             │
    │ AAPL   ┆ 2023-01-05 ┆ 1                 ┆ false             │
    │ GOOGL  ┆ 2023-01-01 ┆ 3                 ┆ true              │
    │ GOOGL  ┆ 2023-01-02 ┆ 3                 ┆ true              │
    │ GOOGL  ┆ 2023-01-03 ┆ 3                 ┆ true              │
    │ MSFT   ┆ 2023-01-01 ┆ 3                 ┆ true              │
    │ MSFT   ┆ 2023-01-02 ┆ 3                 ┆ true              │
    │ MSFT   ┆ 2023-01-03 ┆ 3                 ┆ true              │
    │ MSFT   ┆ 2023-01-04 ┆ 2                 ┆ false             │
    └────────┴────────────┴───────────────────┴───────────────────┘