I have data that looks like:
import polars as pl

lf = pl.LazyFrame(
    {
        "points": [
            [
                [1.0, 2.0],
            ],
            [
                [3.0, 4.0],
                [5.0, 6.0],
            ],
            [
                [7.0, 8.0],
                [9.0, 10.0],
                [11.0, 12.0],
            ],
        ],
        "other": ["foo", "bar", "baz"],
    },
    schema={
        "points": pl.List(pl.Array(pl.Float32, 2)),
        "other": pl.String,
    },
)
And I want to make all lists have the same number of elements. If it currently has more than I need, it should truncate. If it has less than I need, it should repeat itself in order until it has enough.
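To make the intended behaviour concrete, here is a plain-Python sketch of what should happen to a single list. It is not Polars code; `cycle_to_length` is a hypothetical helper name, and raising on an empty list is an assumption (an empty list has nothing to repeat).

```python
from itertools import cycle, islice


def cycle_to_length(items, target_length):
    """Truncate `items`, or repeat it in order, until it has `target_length` elements."""
    if not items:
        # Assumption: an empty list cannot be padded by repetition.
        raise ValueError("cannot pad an empty list by repeating it")
    # cycle() repeats the list in order; islice() stops after target_length items.
    return list(islice(cycle(items), target_length))


print(cycle_to_length([[1.0, 2.0]], 3))
# [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
print(cycle_to_length([[3.0, 4.0], [5.0, 6.0]], 3))
# [[3.0, 4.0], [5.0, 6.0], [3.0, 4.0]]
print(cycle_to_length([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0], [13.0, 14.0]], 3))
# [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
```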
I managed to get it working, but I feel I am jumping through hoops. Is there a cleaner way of doing this, maybe with gather?
target_length = 3
result = (
    lf.with_columns(
        needed=pl.lit(target_length).truediv(pl.col("points").list.len()).ceil()
    )
    .with_columns(
        pl.col("points")
        .repeat_by("needed")
        .list.eval(pl.element().explode())
        .list.head(target_length)
    )
    .drop("needed")
)
EDIT
The method above works for toy examples, but when I try to use it in my real dataset, it fails with:
pyo3_runtime.PanicException: Polars' maximum length reached. Consider installing 'polars-u64-idx'.
I haven't been able to make an MRE for this, but my data has 4 million rows, and the "points" list on each row has between 1 and 8000 elements (and I'm trying to pad/truncate to 800 elements). These all seem pretty small, so I don't see how a maximum u32 length is reached.
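For reference, a rough back-of-the-envelope count of the values involved, using only the sizes quoted above. Whether the limit applies to these totals (summed over all rows) rather than to each list individually is an assumption and has not been confirmed; the sketch only shows which totals would exceed the u32 range if it did.

```python
U32_MAX = 2**32 - 1  # largest value a u32 index can hold

rows = 4_000_000
target_length = 800
max_list_len = 8_000
array_width = 2  # each point is an Array(Float32, 2)


def repeated_len(list_len, target):
    """Length of one list after repeat_by + explode, before .list.head()."""
    needed = -(-target // list_len)  # integer ceil(target / list_len)
    return needed * list_len


# Final result: every row holds exactly target_length points.
final_points = rows * target_length
final_floats = final_points * array_width

# Upper bound on the intermediate: the longest a single repeated list can get.
longest_intermediate = max(repeated_len(n, target_length) for n in range(1, max_list_len + 1))
worst_case_points = rows * longest_intermediate

print(final_points > U32_MAX)       # False: 3_200_000_000
print(final_floats > U32_MAX)       # True:  6_400_000_000
print(worst_case_points > U32_MAX)  # True:  32_000_000_000 (upper bound)
```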
I appreciate any alternative approaches I can try.
The closest I have (which doesn't panic) is the following. However, it doesn't pad by repeating the list in order; it just repeats the last element.
target_length = 3
result = (
    lf.with_columns(
        pl.col("points")
        .list.gather(
            pl.int_range(target_length),
            null_on_oob=True,
        )
        # out-of-bounds indices become null; forward_fill copies the last element into them
        .list.eval(pl.element().forward_fill())
    )
)
The repr defaults for lists are quite small, so we will increase them for the example.
pl.Config(fmt_table_cell_list_len=8, fmt_str_lengths=120)
If you use pl.int_ranges() (plural) and modulo arithmetic, you can generate the indices.
target_length = 5
lf.select(pl.int_ranges(target_length) % pl.col("points").list.len()).collect()
shape: (3, 1)
┌─────────────────┐
│ literal         │
│ ---             │
│ list[i64]       │
╞═════════════════╡
│ [0, 0, 0, 0, 0] │
│ [0, 1, 0, 1, 0] │
│ [0, 1, 2, 0, 1] │
└─────────────────┘
You can then pass these indices to .list.gather():
lf.with_columns(
    pl.col("points").list.gather(
        pl.int_ranges(target_length) % pl.col("points").list.len()
    )
).collect()
shape: (3, 2)
┌──────────────────────────────────────────────────────────────────┬───────┐
│ points                                                           ┆ other │
│ ---                                                              ┆ ---   │
│ list[array[f32, 2]]                                              ┆ str   │
╞══════════════════════════════════════════════════════════════════╪═══════╡
│ [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]     ┆ foo   │
│ [[3.0, 4.0], [5.0, 6.0], [3.0, 4.0], [5.0, 6.0], [3.0, 4.0]]     ┆ bar   │
│ [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0], [7.0, 8.0], [9.0, 10.0]] ┆ baz   │
└──────────────────────────────────────────────────────────────────┴───────┘