Suppose we have a Polars LazyFrame like this:
# Sketch of the input frame: each `...` is a placeholder for that column's
# real values and must be filled in before this can be run.
lf = pl.LazyFrame([
    pl.Series("a", ...),
    pl.Series("b", ...),
    pl.Series("c", ...),
    pl.Series("i", ...)
])
and a function like this:
def update(a, b, c, i):
    """Return ``a``, ``b`` and ``c`` each divided by ``a + b + c + i``.

    ``i`` contributes to the total but is not itself returned.
    """
    total = a + b + c + i
    # In-place division, matching the original operators exactly.
    a /= total
    b /= total
    c /= total
    return a, b, c
that depends on the elements of columns `a`, `b` and `c`, and also on `i`.
How can we update each row of a frame using the function?
We could use `with_columns` to update each column independently, but how can we do it when the update of each column depends on several columns?
Edit
In response to comments from @roman, let's tighten up the question.
Use this LazyFrame:
# Example frame: five small Int8 columns. a, b and c are to be rewritten,
# i feeds into the update, and o must pass through untouched.
lf = pl.LazyFrame(
    [
        pl.Series(name, values, dtype=pl.Int8)
        for name, values in (
            ("a", [1, 2, 3, 4]),
            ("b", [5, 6, 7, 8]),
            ("c", [9, 0, 1, 2]),
            ("i", [3, 4, 5, 6]),
            ("o", [7, 8, 9, 0]),
        )
    ]
)
We want to update columns a, b & c in a way that depends on column i. Column o should be unaffected. We have a function that takes the values a, b, c & i and returns a, b, c & i, where the first three have been updated but i remains the same as the input. After updating, all columns should have the same dtype as before.
The closest we can get is using an update function like this.
def update(args):
    """Normalise a, b and c of one row by the row total a + b + c + i.

    ``args`` is the per-row mapping that ``map_elements`` passes for a
    struct column. The a, b and c entries are divided in place, and the
    mapping's values view is returned (a, b, c, i in key order).
    """
    total = args["a"] + args["b"] + args["c"] + args["i"]
    for key in ("a", "b", "c"):
        args[key] /= total
    return args.values()
and applying it like this
# Pack the four inputs into one struct per row, run `update` on every row,
# then turn the returned list back into named columns.
(
    lf.select(
        pl.struct("a", "b", "c", "i")
        .map_elements(update, return_dtype=pl.List(pl.Float64))
        .list.to_struct(fields=["a", "b", "c", "i"])
        .alias("result"),
    )
    .unnest("result")
    .collect()
)
But this has some problems. For example, column `o` is dropped from the result, and `i` comes back as Float64 instead of Int8.
Is there a better way?
Update (for the edited question):
# Keep the untouched columns (i, o) as they are, and rebuild a, b, c from the
# values `update` returns for each row (field names assigned by position).
(
    lf.select(
        pl.exclude("a", "b", "c"),
        pl.struct(pl.col("*"))
        .map_elements(update, return_dtype=pl.List(pl.Float64))
        .list.to_struct(fields=["a", "b", "c"])
        .alias("result"),
    )
    .unnest("result")
    .collect()
)
shape: (4, 5)
┌─────┬─────┬──────────┬──────────┬────────┐
│ i ┆ o ┆ a ┆ b ┆ c │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i8 ┆ i8 ┆ f64 ┆ f64 ┆ f64 │
╞═════╪═════╪══════════╪══════════╪════════╡
│ 3 ┆ 7 ┆ 0.055556 ┆ 0.277778 ┆ 0.5 │
│ 4 ┆ 8 ┆ 0.166667 ┆ 0.5 ┆ 0.0 │
│ 5 ┆ 9 ┆ 0.1875 ┆ 0.4375 ┆ 0.0625 │
│ 6 ┆ 0 ┆ 0.2 ┆ 0.4 ┆ 0.1 │
└─────┴─────┴──────────┴──────────┴────────┘
Or you can change your function to return a dictionary and skip the lists:
def update(args):
    """Divide a, b and c of a row mapping by a + b + c + i, in place.

    Every other field (e.g. ``i`` or ``o``) is left as is. The same mutated
    mapping is returned, so it can be used directly as a struct row.
    """
    row_total = args["a"] + args["b"] + args["c"] + args["i"]
    for field in ("a", "b", "c"):
        args[field] /= row_total
    return args
# `update` returns the whole row mapping, so the result is declared as a
# struct and unnested directly, with no intermediate list.
(
    lf.select(
        pl.struct(pl.col("*"))
        .map_elements(update, return_dtype=pl.Struct)
        .alias("result")
    )
    .unnest("result")
    .collect()
)
shape: (4, 5)
┌──────────┬──────────┬────────┬─────┬─────┐
│ a ┆ b ┆ c ┆ i ┆ o │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ f64 ┆ f64 ┆ f64 ┆ i64 ┆ i64 │
╞══════════╪══════════╪════════╪═════╪═════╡
│ 0.055556 ┆ 0.277778 ┆ 0.5 ┆ 3 ┆ 7 │
│ 0.166667 ┆ 0.5 ┆ 0.0 ┆ 4 ┆ 8 │
│ 0.1875 ┆ 0.4375 ┆ 0.0625 ┆ 5 ┆ 9 │
│ 0.2 ┆ 0.4 ┆ 0.1 ┆ 6 ┆ 0 │
└──────────┴──────────┴────────┴─────┴─────┘
Original answer: In general, I'd advise against using Python functions and would try to stay within pure Polars expressions. In your case it could look, for example, like this:
# Small example frame with four integer columns (Polars infers Int64).
lf = pl.LazyFrame(
    [
        pl.Series(name, values)
        for name, values in (
            ("a", [1, 2, 3]),
            ("b", [2, 3, 4]),
            ("c", [5, 6, 7]),
            ("i", [7, 8, 7]),
        )
    ]
)
# Divide every column except `i` by the row-wise sum of all columns
# (including `i`), using only native Polars expressions.
# `lf` is lazy: `.collect()` executes the plan so that the resulting table
# (rather than the query plan) is what gets displayed.
(
    lf
    .with_columns(pl.exclude("i") / pl.sum_horizontal(pl.all()))
    .collect()
)
shape: (3, 4)
┌──────────┬──────────┬──────────┬─────┐
│ a ┆ b ┆ c ┆ i │
│ --- ┆ --- ┆ --- ┆ --- │
│ f64 ┆ f64 ┆ f64 ┆ i64 │
╞══════════╪══════════╪══════════╪═════╡
│ 0.066667 ┆ 0.133333 ┆ 0.333333 ┆ 7 │
│ 0.105263 ┆ 0.157895 ┆ 0.315789 ┆ 8 │
│ 0.142857 ┆ 0.190476 ┆ 0.333333 ┆ 7 │
└──────────┴──────────┴──────────┴─────┘
If you really want to pass the function, you can do it like this:
def update(row):
    """Normalise a, b and c of one row by the row total a + b + c + i.

    ``row`` is the per-row mapping that ``map_elements`` passes for a struct
    column. Returns the three normalised values ``(a, b, c)``; ``i`` only
    contributes to the total.
    """
    a, b, c, i = (row[x] for x in "abci")
    s = a + b + c + i
    # Each of a, b and c must be divided exactly once by the same total.
    a /= s
    b /= s
    c /= s
    return a, b, c


(
    lf.select(
        pl.col.i,
        pl.struct(pl.all())
        .map_elements(update, return_dtype=pl.List(pl.Float64))
        .list.to_struct(fields=["a", "b", "c"])
        .alias("result"),
    )
    .unnest("result")
    .collect()
)
# Output (matches the pure-expression version above):
# shape: (3, 4)
# ┌─────┬──────────┬──────────┬──────────┐
# │ i   ┆ a        ┆ b        ┆ c        │
# │ --- ┆ ---      ┆ ---      ┆ ---      │
# │ i64 ┆ f64      ┆ f64      ┆ f64      │
# ╞═════╪══════════╪══════════╪══════════╡
# │ 7   ┆ 0.066667 ┆ 0.133333 ┆ 0.333333 │
# │ 8   ┆ 0.105263 ┆ 0.157895 ┆ 0.315789 │
# │ 7   ┆ 0.142857 ┆ 0.190476 ┆ 0.333333 │
# └─────┴──────────┴──────────┴──────────┘