
Compare polars list to python list


Update: Support for direct list comparison was added to Polars.

df.filter(pl.concat_list(pl.all()) == [2, 5, 8])

Say I have this:

import polars

df = polars.DataFrame(dict(
  j=[1,2,3],
  k=[4,5,6],
  l=[7,8,9],
  ))

shape: (3, 3)
┌─────┬─────┬─────┐
│ j   ┆ k   ┆ l   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 4   ┆ 7   │
│ 2   ┆ 5   ┆ 8   │
│ 3   ┆ 6   ┆ 9   │
└─────┴─────┴─────┘

I can filter for a particular row by comparing one column at a time, i.e.:

df = df.filter(
  (polars.col('j') == 2) &
  (polars.col('k') == 5) &
  (polars.col('l') == 8)
  )

shape: (1, 3)
┌─────┬─────┬─────┐
│ j   ┆ k   ┆ l   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 2   ┆ 5   ┆ 8   │
└─────┴─────┴─────┘

I'd rather compare against a list directly, though, so that I don't have to spell out each column and the filter still works when the number of columns varies. Something like:

df = df.filter(
    polars.concat_list(polars.all()) == [2, 5, 8]
    )

...
exceptions.ArrowErrorException: NotYetImplemented("Casting from Int64 to LargeList(Field { name: \"item\", data_type: Int64, is_nullable: true, metadata: {} }) not supported")

Any ideas why the above is throwing the exception?

I can build the expression manually:

import functools

df = df.filter(
  functools.reduce(lambda a, e: a & e, (polars.col(c) == v for c, v in zip(df.columns, [2, 5, 8])))
  )

but I was hoping there's a way to compare lists directly - e.g. as if I had this DataFrame originally:

df = polars.DataFrame(dict(j=[
  [1,4,7],
  [2,5,8],
  [3,6,9],
  ]))

shape: (3, 1)
┌───────────┐
│ j         │
│ ---       │
│ list[i64] │
╞═══════════╡
│ [1, 4, 7] │
│ [2, 5, 8] │
│ [3, 6, 9] │
└───────────┘

and wanted to find the row which matches [2, 5, 8]. Any hints?


Solution

  • You can pass multiple conditions to pl.all_horizontal() instead of combining them with functools.reduce.

    For a list column, you can compare the values at each index with .list.get():

    df.filter(
       pl.all_horizontal(
          pl.col("j").list.get(n) == row[n]
          for row in [[2, 5, 8]]
          for n in range(len(row))
       )
    )
    
    shape: (1, 1)
    ┌───────────┐
    │ j         │
    │ ---       │
    │ list[i64] │
    ╞═══════════╡
    │ [2, 5, 8] │
    └───────────┘
    

    I'm not sure why this doesn't work:

    df.filter(pl.col("j") == pl.lit([[2, 5, 8]]))
    
    shape: (0, 1)
    ┌───────────┐
    │ j         │
    │ ---       │
    │ list[i64] │
    ╞═══════════╡
    └───────────┘
    

    For regular columns, you could modify your example:

    df.filter(
       pl.all_horizontal(
          pl.col(col) == value
          for col, value in zip(df.columns, [2, 5, 8])
       )
    )