
Non-equi join in polars


If you come from the future, hopefully this PR has already been merged.

If you don't come from the future, hopefully this answer solves your problem.

I want to solve my problem with polars alone (I am no expert in it, but I can follow what is going on) before resorting to copy-pasting the DuckDB integration suggested above and comparing the results on my real data.

I have a list of events (name and timestamp), and a list of time windows. I want to count how many of each event occur in each time window.

I feel like I am close to getting something that works correctly, but I have been stuck for a couple of hours now:

import polars as pl

events = {
    "name": ["a", "b", "a", "b", "a", "c", "b", "a", "b", "a", "b", "a", "b", "a", "b", "a", "b", "a", "b"],
    "time": [0.0, 1.0, 1.5, 2.0, 2.25, 2.26, 2.45, 2.5, 3.0, 3.4, 3.5, 3.6, 3.65, 3.7, 3.8, 4.0, 4.5, 5.0, 6.0],
}

windows = {
    "start_time": [1.0, 2.0, 3.0, 4.0],
    "stop_time": [3.5, 2.5, 3.7, 5.0],
}

events_df = pl.DataFrame(events).sort("time").with_row_index()
windows_df = (
    pl.DataFrame(windows)
    .sort("start_time")
    .join_asof(events_df, left_on="start_time", right_on="time", strategy="forward")
    .drop("name", "time")
    .rename({"index": "first_index"})
    .sort("stop_time")
    .join_asof(events_df, left_on="stop_time", right_on="time", strategy="backward")
    .drop("name", "time")
    .rename({"index": "last_index"})
)

print(windows_df)
"""
shape: (4, 4)
┌────────────┬───────────┬─────────────┬────────────┐
│ start_time ┆ stop_time ┆ first_index ┆ last_index │
│ ---        ┆ ---       ┆ ---         ┆ ---        │
│ f64        ┆ f64       ┆ u32         ┆ u32        │
╞════════════╪═══════════╪═════════════╪════════════╡
│ 2.0        ┆ 2.5       ┆ 3           ┆ 7          │
│ 1.0        ┆ 3.5       ┆ 1           ┆ 10         │
│ 3.0        ┆ 3.7       ┆ 8           ┆ 13         │
│ 4.0        ┆ 5.0       ┆ 15          ┆ 17         │
└────────────┴───────────┴─────────────┴────────────┘
"""

So far, for each time window, I can get the index of the first and last events that I care about. Now I "just" need to count how many of these are of each type. Can I get some help on how to do this?

The output I am looking for should look like:

shape: (4, 5)
┌────────────┬───────────┬─────┬─────┬─────┐
│ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
│ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
│ f64        ┆ f64       ┆ i64 ┆ i64 ┆ i64 │
╞════════════╪═══════════╪═════╪═════╪═════╡
│ 1.0        ┆ 3.5       ┆ 4   ┆ 5   ┆ 1   │
│ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
│ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
│ 4.0        ┆ 5.0       ┆ 2   ┆ 1   ┆ 0   │
└────────────┴───────────┴─────┴─────┴─────┘

I feel like using something like int_ranges(), gather(), and explode() can get me a dataframe with each time window and all its corresponding events. Finally, something like group_by(), count(), and pivot() can get me to the dataframe I want. But I have been struggling with this for a while.


Solution

  • Update: join_where() was released in polars 1.7.0:

    (
        windows_df
        .join_where(
            events_df,
            pl.col.time >= pl.col.start_time,
            pl.col.time <= pl.col.stop_time,
        )
        .sort("name", "start_time")
        .pivot(on="name", index=["start_time","stop_time"], aggregate_function="len", values="time")
        .fill_null(0)
    )
    
    ┌────────────┬───────────┬─────┬─────┬─────┐
    │ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
    │ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
    │ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
    ╞════════════╪═══════════╪═════╪═════╪═════╡
    │ 1.0        ┆ 3.5       ┆ 4   ┆ 5   ┆ 1   │
    │ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
    │ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
    │ 4.0        ┆ 5.0       ┆ 2   ┆ 1   ┆ 0   │
    └────────────┴───────────┴─────┴─────┴─────┘
    

  • Previous answer: not sure if it will be more performant, but you can transform your windows_df into the desired output with:

    (
        windows_df
        # int_ranges() excludes its end bound, so add 1 to keep last_index itself
        .with_columns(index = pl.int_ranges(pl.col.first_index, pl.col.last_index + 1, dtype=pl.UInt32))
        .explode("index")
        .join(events_df, on="index", how="inner")
        .pivot(on="name", index=["start_time","stop_time"], aggregate_function="len", values="index", sort_columns=True)
        .fill_null(0)
        .sort("start_time")
    )

    ┌────────────┬───────────┬─────┬─────┬─────┐
    │ start_time ┆ stop_time ┆ a   ┆ b   ┆ c   │
    │ ---        ┆ ---       ┆ --- ┆ --- ┆ --- │
    │ f64        ┆ f64       ┆ u32 ┆ u32 ┆ u32 │
    ╞════════════╪═══════════╪═════╪═════╪═════╡
    │ 1.0        ┆ 3.5       ┆ 4   ┆ 5   ┆ 1   │
    │ 2.0        ┆ 2.5       ┆ 2   ┆ 2   ┆ 1   │
    │ 3.0        ┆ 3.7       ┆ 3   ┆ 3   ┆ 0   │
    │ 4.0        ┆ 5.0       ┆ 2   ┆ 1   ┆ 0   │
    └────────────┴───────────┴─────┴─────┴─────┘