Tags: python, dataframe, data-science, python-polars

Is there a way to cumulatively and distinctly expand a list in Polars?


For instance, I want to accomplish a conversion like the one below.

import polars as pl

df = pl.DataFrame({
    "col": [["a"], ["a", "b"], ["c"]]
})
┌────────────┐
│ col        │
│ ---        │
│ list[str]  │
╞════════════╡
│ ["a"]      │
│ ["a", "b"] │
│ ["c"]      │
└────────────┘
↓
↓
↓
┌────────────┬─────────────────┐
│ col        ┆ col_cum         │
│ ---        ┆ ---             │
│ list[str]  ┆ list[str]       │
╞════════════╪═════════════════╡
│ ["a"]      ┆ ["a"]           │
│ ["a", "b"] ┆ ["a", "b"]      │
│ ["c"]      ┆ ["a", "b", "c"] │
└────────────┴─────────────────┘

I've tried polars.Expr.cumulative_eval(), but could not get it to work.

I can access the first and last elements in every iteration, but what I think I need here is the result of the previous iteration. Could I get some help?
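Wanting "the result of the previous iteration" amounts to a simple recurrence: the output for row i is the output for row i-1 plus any elements of row i not already seen. The plain-Python sketch below (no Polars; `cumulative_distinct` is a hypothetical helper, not a library function) states that recurrence explicitly and keeps elements in first-seen order.

```python
def cumulative_distinct(rows: list[list[str]]) -> list[list[str]]:
    """Hypothetical helper: out[i] = out[i-1] plus unseen elements of rows[i]."""
    seen: set[str] = set()   # elements already emitted
    ordered: list[str] = []  # the running result, in first-seen order
    out: list[list[str]] = []
    for row in rows:
        for item in row:
            if item not in seen:
                seen.add(item)
                ordered.append(item)
        out.append(list(ordered))  # copy so later rows don't mutate earlier ones
    return out


print(cumulative_distinct([["a"], ["a", "b"], ["c"]]))
# [['a'], ['a', 'b'], ['a', 'b', 'c']]
```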


Solution

  • We can use the cumulative_eval expression.

    But first, let's expand your data to cover a few other cases that may be of interest: multiple groups, elements repeated within a group, and an empty list.

    import polars as pl
    
    df = pl.DataFrame(
        {
            "group": [1, 1, 1, 2, 2, 2, 2],
            "var": [["a"], ["a", "b"], ["c"], ["p"], ["q", "p"], [], ["s"]],
        }
    )
    df
    
    shape: (7, 2)
    ┌───────┬────────────┐
    │ group ┆ var        │
    │ ---   ┆ ---        │
    │ i64   ┆ list[str]  │
    ╞═══════╪════════════╡
    │ 1     ┆ ["a"]      │
    │ 1     ┆ ["a", "b"] │
    │ 1     ┆ ["c"]      │
    │ 2     ┆ ["p"]      │
    │ 2     ┆ ["q", "p"] │
    │ 2     ┆ []         │
    │ 2     ┆ ["s"]      │
    └───────┴────────────┘
    

    The Algorithm

    Here's the heart of the algorithm:

    (
        df
        .with_columns(
            pl.col('var')
            .cumulative_eval(
                pl.element()
                .explode()
                .unique()
                .sort()
                .implode()
            )
            .list.drop_nulls()
            .over('group')
            .alias('cumulative')
        )
    )
    
    shape: (7, 3)
    ┌───────┬────────────┬─────────────────┐
    │ group ┆ var        ┆ cumulative      │
    │ ---   ┆ ---        ┆ ---             │
    │ i64   ┆ list[str]  ┆ list[str]       │
    ╞═══════╪════════════╪═════════════════╡
    │ 1     ┆ ["a"]      ┆ ["a"]           │
    │ 1     ┆ ["a", "b"] ┆ ["a", "b"]      │
    │ 1     ┆ ["c"]      ┆ ["a", "b", "c"] │
    │ 2     ┆ ["p"]      ┆ ["p"]           │
    │ 2     ┆ ["q", "p"] ┆ ["p", "q"]      │
    │ 2     ┆ []         ┆ ["p", "q"]      │
    │ 2     ┆ ["s"]      ┆ ["p", "q", "s"] │
    └───────┴────────────┴─────────────────┘
    

    How it works

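    cumulative_eval lets us treat a subset of rows of a column as if it were a Series in its own right (except that we refer to the elements of that Series using pl.element()). Specifically, for each row, the expression is evaluated on the slice running from the first row up to and including the current row; combined with over('group'), that slice restarts at the beginning of each group.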

    To see what cumulative_eval is doing, let's work with the Series directly and simulate the step where cumulative_eval reaches the last row where group == 1 (the third row).

    The first major step of the algorithm is to explode the lists. explode will put each element of each list on its own row:

    (
        df
        .select(
            pl.col('var')
            .filter(pl.col('group') == 1)
            .explode()
        )
    )
    
    shape: (4, 1)
    ┌─────┐
    │ var │
    │ --- │
    │ str │
    ╞═════╡
    │ a   │
    │ a   │
    │ b   │
    │ c   │
    └─────┘
    

    In the next step, we will use unique and sort to eliminate duplicates and keep the order consistent.

    (
        df
        .select(
            pl.col('var')
            .filter(pl.col('group') == 1)
            .explode()
            .unique()
            .sort()
        )
    )
    
    shape: (3, 1)
    ┌─────┐
    │ var │
    │ --- │
    │ str │
    ╞═════╡
    │ a   │
    │ b   │
    │ c   │
    └─────┘
    

    At this point, we need only roll up all the values into a single list.

    (
        df
        .select(
            pl.col('var')
            .filter(pl.col('group') == 1)
            .explode()
            .unique()
            .sort()
            .implode()
        )
    )
    
    shape: (1, 1)
    ┌─────────────────┐
    │ var             │
    │ ---             │
    │ list[str]       │
    ╞═════════════════╡
    │ ["a", "b", "c"] │
    └─────────────────┘
    

    And that is the value that cumulative_eval returns for the third row.

    Performance

    The documentation for cumulative_eval comes with a strong warning about performance.

    Warning: This can be really slow as it can have O(n^2) complexity. Don’t use this for operations that visit all elements.

    Let's simulate some data. The code below generates about 9.5 million records spread across roughly 10,000 groups, or about 950 observations per group.

    import numpy as np
    from string import ascii_lowercase
    
    rng = np.random.default_rng(1)
    
    nbr_rows = 10_000_000
    df = (
        pl.DataFrame({
            'group': rng.integers(1, 10_000, size=nbr_rows),
            'sub_list': rng.integers(1, 10_000, size=nbr_rows),
            'var': rng.choice(list(ascii_lowercase), nbr_rows)
        })
        .group_by('group', 'sub_list')
        .agg(
            pl.col('var')
        )
        .drop('sub_list')
        .sort('group')
    )
    df
    
    shape: (9515737, 2)
    ┌───────┬────────────┐
    │ group ┆ var        │
    │ ---   ┆ ---        │
    │ i64   ┆ list[str]  │
    ╞═══════╪════════════╡
    │ 1     ┆ ["q", "r"] │
    │ 1     ┆ ["z"]      │
    │ 1     ┆ ["b"]      │
    │ 1     ┆ ["j"]      │
    │ ...   ┆ ...        │
    │ 9999  ┆ ["z"]      │
    │ 9999  ┆ ["e"]      │
    │ 9999  ┆ ["s"]      │
    │ 9999  ┆ ["s"]      │
    └───────┴────────────┘
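To put the O(n^2) warning into numbers for this data: a group with m rows is evaluated on prefixes of length 1, 2, ..., m. Assuming, for a back-of-the-envelope estimate, that every group has exactly 950 rows, the rows visited work out as follows.

```latex
\begin{aligned}
\sum_{i=1}^{m} i &= \frac{m(m+1)}{2}, \qquad m \approx 950 \;\Rightarrow\; \frac{950 \cdot 951}{2} = 451{,}725 \text{ rows per group} \\
451{,}725 \times 10{,}000 &\approx 4.5 \times 10^{9} \text{ rows visited, versus } \approx 9.5 \times 10^{6} \text{ input rows}
\end{aligned}
```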
    

    On my 32-core system, here's the wall-clock time:

    import time
    start = time.perf_counter()
    (
        df
        .with_columns(
            pl.col('var')
            .cumulative_eval(
                pl.element()
                .explode()
                .unique()
                .sort()
                .implode()
            )
            .list.drop_nulls()
            .over('group')
            .alias('cumulative')
        )
    )
    print(time.perf_counter() - start)
    
    shape: (9515737, 3)
    ┌───────┬────────────┬─────────────────────┐
    │ group ┆ var        ┆ cumulative          │
    │ ---   ┆ ---        ┆ ---                 │
    │ i64   ┆ list[str]  ┆ list[str]           │
    ╞═══════╪════════════╪═════════════════════╡
    │ 1     ┆ ["q", "r"] ┆ ["q", "r"]          │
    │ 1     ┆ ["z"]      ┆ ["q", "r", "z"]     │
    │ 1     ┆ ["b"]      ┆ ["b", "q", ... "z"] │
    │ 1     ┆ ["j"]      ┆ ["b", "j", ... "z"] │
    │ ...   ┆ ...        ┆ ...                 │
    │ 9999  ┆ ["z"]      ┆ ["a", "b", ... "z"] │
    │ 9999  ┆ ["e"]      ┆ ["a", "b", ... "z"] │
    │ 9999  ┆ ["s"]      ┆ ["a", "b", ... "z"] │
    │ 9999  ┆ ["s"]      ┆ ["a", "b", ... "z"] │
    └───────┴────────────┴─────────────────────┘
    118.46121257600134
    

    That's roughly 2 minutes on my system for 9.5 million records. Depending on your system, you may see better or worse performance, but the point is that it didn't take hours to complete.

    If you need better performance, we can come up with a better-performing algorithm (or perhaps put in a feature request for a cumlist feature in Polars, which might have better complexity than the O(n^2) of cumulative_eval).