
Polars Python: Filter list column using a boolean list column, but keeping list size


I would like to select elements from a list column using another boolean list column while keeping the original size of each list (as opposed to the usual filtering approach, which drops the non-matching elements).

Starting from this dataframe:

import polars as pl

df = pl.DataFrame({
    'identity_vector': [[True, False], [False, True]],
    'string_vector': [['name1', 'name2'], ['name3', 'name4']]
})

shape: (2, 2)
┌─────────────────┬────────────────────┐
│ identity_vector ┆ string_vector      │
│ ---             ┆ ---                │
│ list[bool]      ┆ list[str]          │
╞═════════════════╪════════════════════╡
│ [true, false]   ┆ ["name1", "name2"] │
│ [false, true]   ┆ ["name3", "name4"] │
└─────────────────┴────────────────────┘

The objective is to get this output:

shape: (2, 3)
┌─────────────────┬────────────────────┬──────────────────┐
│ identity_vector ┆ string_vector      ┆ filtered_strings │
│ ---             ┆ ---                ┆ ---              │
│ list[bool]      ┆ list[str]          ┆ list[str]        │
╞═════════════════╪════════════════════╪══════════════════╡
│ [true, false]   ┆ ["name1", "name2"] ┆ ["name1", null]  │
│ [false, true]   ┆ ["name3", "name4"] ┆ [null, "name4"]  │
└─────────────────┴────────────────────┴──────────────────┘

I can get this with map_elements using the code below, but it is slow because the lambda runs in Python once per row instead of in Polars' native engine:

df.with_columns(
    filtered_strings=pl.struct(["string_vector", "identity_vector"]).map_elements(
        lambda row: [
            s if keep else None
            for s, keep in zip(row["string_vector"], row["identity_vector"])
        ],
        return_dtype=pl.List(pl.String),
    )
)

Do you have any suggestions on how to improve the performance of this process?


Solution

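  • The fairly standard pl.Expr.explode() / compute / pl.Expr.implode() route. Running the expression .over(pl.int_range(pl.len())) makes each row its own group (pl.int_range(pl.len()) is a unique row number), so the exploded elements are imploded back into one list per row: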

    df.with_columns(
        pl.when(
            pl.col.identity_vector.explode()
        ).then(
            pl.col.string_vector.explode()
        ).otherwise(None)
        .implode()
        .over(pl.int_range(pl.len()))
        .alias("filtered_strings")
    )
    
    shape: (2, 3)
    ┌─────────────────┬────────────────────┬──────────────────┐
    │ identity_vector ┆ string_vector      ┆ filtered_strings │
    │ ---             ┆ ---                ┆ ---              │
    │ list[bool]      ┆ list[str]          ┆ list[str]        │
    ╞═════════════════╪════════════════════╪══════════════════╡
    │ [true, false]   ┆ ["name1", "name2"] ┆ ["name1", null]  │
    │ [false, true]   ┆ ["name3", "name4"] ┆ [null, "name4"]  │
    └─────────────────┴────────────────────┴──────────────────┘
    

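    There are other possible approaches as well, for example combining pl.Expr.list.eval() and pl.Expr.list.gather(): list.eval() turns each boolean mask into a list of positions (null where the mask is False), and gathering at a null position yields null, so the list keeps its size: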

    df.with_columns(
        pl.col.string_vector.list.gather(
            pl.col.identity_vector.list.eval(
                pl.when(pl.element()).then(pl.int_range(pl.len()))
            )
        ).alias("filtered_strings")
    )
    
    shape: (2, 3)
    ┌─────────────────┬────────────────────┬──────────────────┐
    │ identity_vector ┆ string_vector      ┆ filtered_strings │
    │ ---             ┆ ---                ┆ ---              │
    │ list[bool]      ┆ list[str]          ┆ list[str]        │
    ╞═════════════════╪════════════════════╪══════════════════╡
    │ [true, false]   ┆ ["name1", "name2"] ┆ ["name1", null]  │
    │ [false, true]   ┆ ["name3", "name4"] ┆ [null, "name4"]  │
    └─────────────────┴────────────────────┴──────────────────┘
    

    Or, if you know the length of your lists and it is relatively small, you can build one expression per list index with pl.Expr.list.get() and combine them with pl.concat_list().

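    # length of every list (all lists must have this length)
    l = 2
    df.with_columns(
        filtered_strings=pl.concat_list(
            [
                # element i: the string where the mask is True, otherwise null
                pl.when(
                    pl.col.identity_vector.list.get(i)
                ).then(
                    pl.col.string_vector.list.get(i)
                )
                for i in range(l)
            ]
        )
    )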
    
    shape: (2, 3)
    ┌─────────────────┬────────────────────┬──────────────────┐
    │ identity_vector ┆ string_vector      ┆ filtered_strings │
    │ ---             ┆ ---                ┆ ---              │
    │ list[bool]      ┆ list[str]          ┆ list[str]        │
    ╞═════════════════╪════════════════════╪══════════════════╡
    │ [true, false]   ┆ ["name1", "name2"] ┆ ["name1", null]  │
    │ [false, true]   ┆ ["name3", "name4"] ┆ [null, "name4"]  │
    └─────────────────┴────────────────────┴──────────────────┘
    

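    All three solutions rely on pl.when() to produce null where the condition is not met: a pl.when().then() chain without .otherwise() yields null for every row whose condition is false or null.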