I am trying to find the elements that each list in a list column has in common with a reference cell. I have it working on a small dataset, but I face two problems: it is excruciatingly slow even for 25 rows of sample data (20.7 s ± 52 ms per loop), and I am unable to find a faster implementation through map_batches, which can use parallelization, unlike map_elements, which works on a single thread.
The version that I have working right now is as follows:
import polars as pl
import numpy as np
df = pl.DataFrame({'animal': ['goat','tiger','goat','tiger','lion','goat','tiger','lion'], 'food': ['grass','rabbit','carrots','deer','zebra','water','water','water']})
dl = df.group_by('animal', maintain_order=True).all()
dl
shape: (3, 2)
┌────────┬───────────────────────────────┐
│ animal ┆ food │
│ --- ┆ --- │
│ str ┆ list[str] │
╞════════╪═══════════════════════════════╡
│ goat ┆ ["grass", "carrots", "water"] │
│ tiger ┆ ["rabbit", "deer", "water"] │
│ lion ┆ ["zebra", "water"] │
└────────┴───────────────────────────────┘
refn = dl['food'][1].to_numpy()
dl = dl.with_columns(
    pl.col('food').map_elements(
        lambda x: np.intersect1d(refn, x.to_numpy()),
        return_dtype=pl.List(pl.String),
    )
)
dl
shape: (3, 2)
┌────────┬─────────────────────────────┐
│ animal ┆ food │
│ --- ┆ --- │
│ str ┆ list[str] │
╞════════╪═════════════════════════════╡
│ goat ┆ ["water"] │
│ tiger ┆ ["deer", "rabbit", "water"] │
│ lion ┆ ["water"] │
└────────┴─────────────────────────────┘
Any help will be greatly appreciated. TIA.
Update: Polars has since added dedicated set operations for lists.
(df.group_by('animal', maintain_order=True)
   .all()
   .with_columns(
       pl.col("food").list.set_intersection(pl.col("food").get(1))
   )
)
pl.col("food").get(1) is the expression equivalent of the refn assignment.
Original answer
If you are using Python, NumPy, or sets for membership tests in Polars, you are on the wrong path most of the time. This actually holds for map_elements in general.
Let's start off where you left off. You took index 1 of the column "food", which has dtype List. That gives us a Series, assigned to refn.
dl = df.group_by('animal', maintain_order=True).all()
refn = dl['food'][1]
refn
shape: (3,)
Series: 'food' [str]
[
"rabbit"
"deer"
"water"
]
We simply filter the original df by membership in refn and then aggregate the result again. This solution does an O(n) search for every element, so it is likely not the fastest.
(df.filter(pl.col("food").is_in(refn))
   .group_by("animal").all()
)
A better option is a semi join. A semi join filters the DataFrame by membership in another DataFrame, which sounds like exactly what we want!
(df.join(refn.to_frame(), on="food", how="semi")
   .group_by("animal").all()
)