
PyPolars: Speed up apply function to find common elements


I am trying to find the elements that each list in a list column has in common with a reference cell. I got it working on a small dataset, but I face two problems: it is excruciatingly slow even for 25 rows of sample data (20.7 s ± 52 ms per loop), and I am unable to find a faster implementation through map_batches, which can use parallelization, unlike map_elements, which works on a single thread.

The version that I have working right now is as follows:

import polars as pl
import numpy as np

df = pl.DataFrame({'animal': ['goat','tiger','goat','tiger','lion','goat','tiger','lion'], 'food': ['grass','rabbit','carrots','deer','zebra','water','water','water']})
dl = df.group_by('animal', maintain_order=True).all()
dl
shape: (3, 2)
┌────────┬───────────────────────────────┐
│ animal ┆ food                          │
│ ---    ┆ ---                           │
│ str    ┆ list[str]                     │
╞════════╪═══════════════════════════════╡
│ goat   ┆ ["grass", "carrots", "water"] │
│ tiger  ┆ ["rabbit", "deer", "water"]   │
│ lion   ┆ ["zebra", "water"]            │
└────────┴───────────────────────────────┘
refn = dl['food'][1].to_numpy()

dl = dl.with_columns(
    pl.col('food').map_elements(lambda x: np.intersect1d(refn,x.to_numpy()), return_dtype=pl.List(pl.String))
)

dl
shape: (3, 2)
┌────────┬─────────────────────────────┐
│ animal ┆ food                        │
│ ---    ┆ ---                         │
│ str    ┆ list[str]                   │
╞════════╪═════════════════════════════╡
│ goat   ┆ ["water"]                   │
│ tiger  ┆ ["deer", "rabbit", "water"] │
│ lion   ┆ ["water"]                   │
└────────┴─────────────────────────────┘

Any help will be greatly appreciated. TIA.


Solution

  • Update: Polars has since added dedicated set operations for lists.

    e.g. .list.set_intersection()

    (df.group_by('animal', maintain_order=True)
       .all()
       .with_columns(
           pl.col("food").list.set_intersection(pl.col("food").get(1)) 
       )
    )
    

    pl.col("food").get(1) is the expression equivalent of the refn assignment.


    Original answer

    If you are reaching for Python, NumPy, or sets to test membership in Polars, you are usually on the wrong path. The same goes for map_elements in general.

    Let's start where you left off. You took the element at index 1 of the "food" column, which has dtype List. That gives us a Series, assigned to refn.

    dl = df.group_by('animal', maintain_order=True).all()
    refn = dl['food'][1]
    refn
    
    shape: (3,)
    Series: 'food' [str]
    [
        "rabbit"
        "deer"
        "water"
    ]
    

    Option 1 (slower)

    This solution does an O(n) search for every element, so it is likely not the fastest.

    We simply filter the original df by membership in refn and then aggregate the result again.

    (df.filter(pl.col("food").is_in(refn))
       .group_by("animal").all()
    )
    

    Option 2 (fastest)

    A better option is a semi join. A semi join filters the DataFrame by membership in another DataFrame, which sounds like exactly what we want!

    (df.join(refn.to_frame(), on="food", how="semi")
       .group_by("animal").all()
    )