
How to repeat and truncate list elements to a fixed length


I have data that looks like:

import polars as pl

lf = pl.LazyFrame(
    {
        "points": [
            [
                [1.0, 2.0],
            ],
            [
                [3.0, 4.0],
                [5.0, 6.0],
            ],
            [
                [7.0, 8.0],
                [9.0, 10.0],
                [11.0, 12.0],
            ],
        ],
        "other": ["foo", "bar", "baz"],
    },
    schema={
        "points": pl.List(pl.Array(pl.Float32, 2)),
        "other": pl.String,
    },
)

And I want every list to have the same number of elements. If a list currently has more elements than I need, it should be truncated. If it has fewer than I need, it should repeat itself in order until it has enough.

I managed to get it working, but I feel I am jumping through hoops. Is there a cleaner way of doing this? Maybe with gather?

target_length = 3

result = (
    lf.with_columns(
        needed=pl.lit(target_length).truediv(pl.col("points").list.len()).ceil()
    )
    .with_columns(
        pl.col("points")
        .repeat_by("needed")
        .list.eval(pl.element().explode())
        .list.head(target_length)
    )
    .drop("needed")
)

EDIT

The method above works for toy examples, but when I try to use it in my real dataset, it fails with:

pyo3_runtime.PanicException: Polars' maximum length reached. Consider installing 'polars-u64-idx'.

I haven't been able to make an MRE for this, but my data has 4 million rows, and the "points" list on each row has between 1 and 8000 elements (I'm trying to pad/truncate to 800 elements). These sizes all seem fairly small, so I don't see how the maximum u32 length could be reached.
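As a rough illustration of where the size could come from, here is a count of how many elements the first approach materializes per row before `.list.head()` runs. This is a sketch under an assumption, not a confirmed diagnosis: it assumes the panic is triggered by the exploded intermediate exceeding the default u32 index limit (2**32 - 1). The sample lengths are hypothetical values within the stated 1 to 8000 range.

```python
import math

U32_MAX = 2**32 - 1   # default Polars index type is u32 (assumed to be the limit hit)
TARGET = 800          # pad/truncate length from the question
N_ROWS = 4_000_000    # row count from the question


def intermediate_len(n: int, target: int = TARGET) -> int:
    """Elements per row after repeat_by + explode, before list.head().

    A row of length n is repeated ceil(target / n) times, so rows longer
    than `target` are kept whole (needed == 1) instead of being cut first.
    """
    return n * math.ceil(target / n)


# Hypothetical sample lengths within the 1..8000 range from the question.
for n in (1, 3, 800, 8000):
    print(n, intermediate_len(n))

# Every row contributes at least TARGET elements; a row of 8000 contributes 8000.
lower_bound_total = N_ROWS * TARGET
all_long_total = N_ROWS * intermediate_len(8000)
print(lower_bound_total > U32_MAX, all_long_total > U32_MAX)
```

If the real length distribution skews long, the intermediate can exceed 2**32 - 1 elements even though the truncated 800-element result would not, which may explain the panic.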

I appreciate any alternative approaches I can try.

The closest I have (which doesn't panic) is:

target_length = 3

result = lf.with_columns(
    pl.col("points")
    .list.gather(
        pl.int_range(target_length),
        null_on_oob=True,
    )
    .list.eval(pl.element().forward_fill())
)

But this doesn't pad by repeating the list in order; it just repeats the last element.

Solution

  • The repr defaults for lists are quite small, so we will increase them for the example.

    pl.Config(fmt_table_cell_list_len=8, fmt_str_lengths=120)
    

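    Using pl.int_ranges() (plural) together with modulo arithmetic, you can generate the gather indices for every row: position i maps to i % len, so short lists wrap around and lists longer than target_length are cut off.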

    target_length = 5
    
    lf.select(pl.int_ranges(target_length) % pl.col("points").list.len()).collect()
    
    shape: (3, 1)
    ┌─────────────────┐
    │ literal         │
    │ ---             │
    │ list[i64]       │
    ╞═════════════════╡
    │ [0, 0, 0, 0, 0] │
    │ [0, 1, 0, 1, 0] │
    │ [0, 1, 2, 0, 1] │
    └─────────────────┘
    

    You can then pass these indices to .list.gather():

    lf.with_columns(
        pl.col("points").list.gather(
            pl.int_ranges(target_length) % pl.col("points").list.len()
        )
    ).collect()
    
    shape: (3, 2)
    ┌──────────────────────────────────────────────────────────────────┬───────┐
    │ points                                                           ┆ other │
    │ ---                                                              ┆ ---   │
    │ list[array[f32, 2]]                                              ┆ str   │
    ╞══════════════════════════════════════════════════════════════════╪═══════╡
    │ [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]     ┆ foo   │
    │ [[3.0, 4.0], [5.0, 6.0], [3.0, 4.0], [5.0, 6.0], [3.0, 4.0]]     ┆ bar   │
    │ [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0], [7.0, 8.0], [9.0, 10.0]] ┆ baz   │
    └──────────────────────────────────────────────────────────────────┴───────┘