Update: Newer versions of Polars support comparing a list expression directly against a Python list:
df.filter(pl.concat_list(pl.all()) == [2, 5, 8])
Say I have this:
import polars

df = polars.DataFrame(dict(
    j=[1, 2, 3],
    k=[4, 5, 6],
    l=[7, 8, 9],
))
shape: (3, 3)
┌─────┬─────┬─────┐
│ j ┆ k ┆ l │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1 ┆ 4 ┆ 7 │
│ 2 ┆ 5 ┆ 8 │
│ 3 ┆ 6 ┆ 9 │
└─────┴─────┴─────┘
I can filter for a particular row by checking one column at a time, e.g.:
df = df.filter(
(polars.col('j') == 2) &
(polars.col('k') == 5) &
(polars.col('l') == 8)
)
shape: (1, 3)
┌─────┬─────┬─────┐
│ j ┆ k ┆ l │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 2 ┆ 5 ┆ 8 │
└─────┴─────┴─────┘
I'd like to compare against the list instead (so I can avoid listing each column and handle DataFrames with a variable number of columns), e.g. something like:
df = df.filter(
polars.concat_list(polars.all()) == [2, 5, 8]
)
...
exceptions.ArrowErrorException: NotYetImplemented("Casting from Int64 to LargeList(Field { name: \"item\", data_type: Int64, is_nullable: true, metadata: {} }) not supported")
Any ideas why the above is throwing the exception?
I can build the expression manually:
import functools

df = df.filter(
    functools.reduce(lambda a, e: a & e, (polars.col(c) == v for c, v in zip(df.columns, [2, 5, 8])))
)
but I was hoping there's a way to compare lists directly, e.g. as if I had originally had this DataFrame:
df = polars.DataFrame(dict(j=[
[1,4,7],
[2,5,8],
[3,6,9],
]))
shape: (3, 1)
┌───────────┐
│ j │
│ --- │
│ list[i64] │
╞═══════════╡
│ [1, 4, 7] │
│ [2, 5, 8] │
│ [3, 6, 9] │
└───────────┘
and wanted to find the row that matches [2, 5, 8]. Any hints?
You can pass multiple conditions to pl.all_horizontal() instead of using functools.reduce.
For a list column, you can compare the value at each index using .list.get():
df.filter(
pl.all_horizontal(
pl.col("j").list.get(n) == row[n]
for row in [[2, 5, 8]]
for n in range(len(row))
)
)
shape: (1, 1)
┌───────────┐
│ j │
│ --- │
│ list[i64] │
╞═══════════╡
│ [2, 5, 8] │
└───────────┘
I'm not sure why this returns no rows:
df.filter(pl.col("j") == pl.lit([[2, 5, 8]]))
shape: (0, 1)
┌───────────┐
│ j │
│ --- │
│ list[i64] │
╞═══════════╡
└───────────┘
For regular (non-list) columns, you can adapt your own example:
df.filter(
pl.all_horizontal(
pl.col(col) == value
for col, value in zip(df.columns, [2, 5, 8])
)
)