
In Polars, how can you update several columns simultaneously?


Suppose we have a Polars LazyFrame like this:

lf = pl.LazyFrame([
    pl.Series("a", ...),
    pl.Series("b", ...),
    pl.Series("c", ...),
    pl.Series("i", ...)
])

and a function like this:

def update(a, b, c, i):
    s = a + b + c + i
    a /= s
    b /= s
    c /= s
    return a, b, c

that depends on the elements of columns a, b and c, and also on i.

How can we update each row of a frame using the function?

We could use with_columns to update the rows of each column independently, but how can we handle the dependency between columns?

Edit

In response to comments from @roman let's tighten up the question.

Use this LazyFrame

lf = pl.LazyFrame(
    [
        pl.Series("a", [1, 2, 3, 4], dtype=pl.Int8),
        pl.Series("b", [5, 6, 7, 8], dtype=pl.Int8),
        pl.Series("c", [9, 0, 1, 2], dtype=pl.Int8),
        pl.Series("i", [3, 4, 5, 6], dtype=pl.Int8),
        pl.Series("o", [7, 8, 9, 0], dtype=pl.Int8),
    ]
)

We want to update columns a, b & c in a way that depends on column i. Column o should be unaffected. We have a function that takes the values a, b, c & i and returns a, b, c & i, where the first three have been updated but i remains the same as the input. After updating, all columns should have the same dtype as before.

The closest we can get is using an update function like this.

def update(args):
    s = args["a"] + args["b"] + args["c"] + args["i"]
    args["a"] /= s
    args["b"] /= s
    args["c"] /= s
    return args.values()

and applying it like this

(
    lf.select(
        pl.struct(pl.col("a", "b", "c", "i"))
        .map_elements(update, return_dtype=pl.List(pl.Float64))
        .list.to_struct(fields=["a", "b", "c", "i"])
        .alias("result"),
    )
    .unnest("result")
    .collect()
)

But this has some problems.

  1. We have lost column o.
  2. Column i has become Float64.
  3. It's pretty ugly.

Is there a better way?


Solution

  • Update, for the edited question:

    lf.select(
        pl.exclude(["a","b","c"]),
        pl
        .struct(pl.all()).map_elements(update, return_dtype=pl.List(pl.Float64))
        .list.to_struct(fields=["a","b","c"])
        .alias("result")
    ).unnest("result").collect()
    
    shape: (4, 5)
    ┌─────┬─────┬──────────┬──────────┬────────┐
    │ i   ┆ o   ┆ a        ┆ b        ┆ c      │
    │ --- ┆ --- ┆ ---      ┆ ---      ┆ ---    │
    │ i8  ┆ i8  ┆ f64      ┆ f64      ┆ f64    │
    ╞═════╪═════╪══════════╪══════════╪════════╡
    │ 3   ┆ 7   ┆ 0.055556 ┆ 0.277778 ┆ 0.5    │
    │ 4   ┆ 8   ┆ 0.166667 ┆ 0.5      ┆ 0.0    │
    │ 5   ┆ 9   ┆ 0.1875   ┆ 0.4375   ┆ 0.0625 │
    │ 6   ┆ 0   ┆ 0.2      ┆ 0.4      ┆ 0.1    │
    └─────┴─────┴──────────┴──────────┴────────┘
    

    Or you can change your function to return a dictionary and skip the lists. Note that in the output below, i and o come back as i64 rather than i8:

    def update(args):
        s = args["a"] + args["b"] + args["c"] + args["i"]
        args["a"] /= s
        args["b"] /= s
        args["c"] /= s
        return args
    
    lf.select(
        pl
        .struct(pl.all()).map_elements(update, return_dtype=pl.Struct)
        .alias("result")
    ).unnest("result").collect()
    
    shape: (4, 5)
    ┌──────────┬──────────┬────────┬─────┬─────┐
    │ a        ┆ b        ┆ c      ┆ i   ┆ o   │
    │ ---      ┆ ---      ┆ ---    ┆ --- ┆ --- │
    │ f64      ┆ f64      ┆ f64    ┆ i64 ┆ i64 │
    ╞══════════╪══════════╪════════╪═════╪═════╡
    │ 0.055556 ┆ 0.277778 ┆ 0.5    ┆ 3   ┆ 7   │
    │ 0.166667 ┆ 0.5      ┆ 0.0    ┆ 4   ┆ 8   │
    │ 0.1875   ┆ 0.4375   ┆ 0.0625 ┆ 5   ┆ 9   │
    │ 0.2      ┆ 0.4      ┆ 0.1    ┆ 6   ┆ 0   │
    └──────────┴──────────┴────────┴─────┴─────┘
    

    Original answer: In general, I'd advise against using Python functions and recommend staying within pure Polars expressions. In your case it could look, for example, like this:

    lf = pl.LazyFrame([
        pl.Series("a", [1,2,3]),
        pl.Series("b", [2,3,4]),
        pl.Series("c", [5,6,7]),
        pl.Series("i", [7,8,7])
    ])
    
    (
        lf
        .with_columns(pl.exclude("i") / pl.sum_horizontal(pl.all()))
        .collect()
    )
    
    shape: (3, 4)
    ┌──────────┬──────────┬──────────┬─────┐
    │ a        ┆ b        ┆ c        ┆ i   │
    │ ---      ┆ ---      ┆ ---      ┆ --- │
    │ f64      ┆ f64      ┆ f64      ┆ i64 │
    ╞══════════╪══════════╪══════════╪═════╡
    │ 0.066667 ┆ 0.133333 ┆ 0.333333 ┆ 7   │
    │ 0.105263 ┆ 0.157895 ┆ 0.315789 ┆ 8   │
    │ 0.142857 ┆ 0.190476 ┆ 0.333333 ┆ 7   │
    └──────────┴──────────┴──────────┴─────┘
    

    If you really want to pass a Python function, you can do it like this:

    def update(row):
        # row is a dict holding one value per column of the struct
        a,b,c,i = [row[x] for x in list("abci")]
        s = a + b + c + i
        a /= s
        b /= s
        c /= s
        return a, b, c
    
    lf.select(
        pl.col.i,
        pl
        .struct(pl.all()).map_elements(update, return_dtype=pl.List(pl.Float64))
        .list.to_struct(fields=["a","b","c"])
        .alias("result")
    ).unnest("result").collect()
    
    shape: (3, 4)
    ┌─────┬──────────┬──────────┬──────────┐
    │ i   ┆ a        ┆ b        ┆ c        │
    │ --- ┆ ---      ┆ ---      ┆ ---      │
    │ i64 ┆ f64      ┆ f64      ┆ f64      │
    ╞═════╪══════════╪══════════╪══════════╡
    │ 7   ┆ 0.066667 ┆ 0.133333 ┆ 0.333333 │
    │ 8   ┆ 0.105263 ┆ 0.157895 ┆ 0.315789 │
    │ 7   ┆ 0.142857 ┆ 0.190476 ┆ 0.333333 │
    └─────┴──────────┴──────────┴──────────┘