
Sort Polars Dataframe columns based on row data


I have this data:

import polars as pl

df = pl.DataFrame({
    'region': ['EU', 'ASIA', 'AMER', 'Year'],
    'Share': [99, 6, -30, 2020],
    'Ration': [70, 4, -10, 2019],
    'Lots': [70, 4, -10, 2018],
    'Stake': [80, 5, -20, 2021],
})
# shape: (4, 5)
# ┌────────┬───────┬────────┬──────┬───────┐
# │ region ┆ Share ┆ Ration ┆ Lots ┆ Stake │
# │ ---    ┆ ---   ┆ ---    ┆ ---  ┆ ---   │
# │ str    ┆ i64   ┆ i64    ┆ i64  ┆ i64   │
# ╞════════╪═══════╪════════╪══════╪═══════╡
# │ EU     ┆ 99    ┆ 70     ┆ 70   ┆ 80    │
# │ ASIA   ┆ 6     ┆ 4      ┆ 4    ┆ 5     │
# │ AMER   ┆ -30   ┆ -10    ┆ -10  ┆ -20   │
# │ Year   ┆ 2020  ┆ 2019   ┆ 2018 ┆ 2021  │
# └────────┴───────┴────────┴──────┴───────┘

I want to order the columns based on the Year row, while leaving the region column first. So ideally I am looking for this:

shape: (4, 5)
┌────────┬──────┬────────┬───────┬───────┐
│ region ┆ Lots ┆ Ration ┆ Share ┆ Stake │
│ ---    ┆ ---  ┆ ---    ┆ ---   ┆ ---   │
│ str    ┆ i64  ┆ i64    ┆ i64   ┆ i64   │
╞════════╪══════╪════════╪═══════╪═══════╡
│ EU     ┆ 70   ┆ 70     ┆ 99    ┆ 80    │
│ ASIA   ┆ 4    ┆ 4      ┆ 6     ┆ 5     │
│ AMER   ┆ -10  ┆ -10    ┆ -30   ┆ -20   │
│ Year   ┆ 2018 ┆ 2019   ┆ 2020  ┆ 2021  │
└────────┴──────┴────────┴───────┴───────┘

How can this be achieved? I tried using polars' sort function, but could not get it to do what I needed.


Solution

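  • .sort() reorders rows, not columns, so it cannot reorder the columns directly.

    You could reshape to the long format with .unpivot(), .sort() the rows by each column's Year value, and then .pivot() back to the wide format.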

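    (df.with_row_index()  # remember the original row order
       .unpivot(index=["index", "region"])  # long format: one row per (region, column)
       # sort by each column's Year value; "index" breaks ties so rows keep their order
       .sort(pl.col("value").filter(region="Year").first().over("variable"), "index")
       .pivot("variable", index="region", values="value")  # back to the wide format
    )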
    
    shape: (4, 5)
    ┌────────┬──────┬────────┬───────┬───────┐
    │ region ┆ Lots ┆ Ration ┆ Share ┆ Stake │
    │ ---    ┆ ---  ┆ ---    ┆ ---   ┆ ---   │
    │ str    ┆ i64  ┆ i64    ┆ i64   ┆ i64   │
    ╞════════╪══════╪════════╪═══════╪═══════╡
    │ EU     ┆ 70   ┆ 70     ┆ 99    ┆ 80    │
    │ ASIA   ┆ 4    ┆ 4      ┆ 6     ┆ 5     │
    │ AMER   ┆ -10  ┆ -10    ┆ -30   ┆ -20   │
    │ Year   ┆ 2018 ┆ 2019   ┆ 2020  ┆ 2021  │
    └────────┴──────┴────────┴───────┴───────┘
    

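    Or do it at the "Python level": iterating over a DataFrame yields its columns as Series, so you can sort the Series by their last value. Note that this relies on Year being the last row.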

    pl.DataFrame(
        [df.select("region").to_series()] +
        sorted(df.drop("region"), key=pl.Series.last)
    )