I have a dataset with a schema like:
import polars as pl

df = pl.DataFrame(
    {
        "target": [
            [1.0, 1.0, 0.0],
            [1.0, 1.0, 0.1],
            [1.0, 1.0, 0.2],
            [1.0, 1.0, 0.8],
            [1.0, 1.0, 0.9],
            [1.0, 1.0, 1.0],
        ],
        "feature": ["a", "b", "c", "d", "e", "f"],
    },
    schema={
        "target": pl.Array(pl.Float32, 3),
        "feature": pl.String,
    },
)
If I make a histogram of the target z-values (the third element of each target array), the distribution is clearly not flat.
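For reference, a minimal sketch of how such a histogram can be computed in Polars (the bin_count value here is arbitrary):

z_values = df.select(z=pl.col("target").arr.get(2))["z"]
print(z_values.hist(bin_count=10))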
I want to resample the data so that it is flat along z.
I managed to do it in a hacky, many-step way that is also very slow, and I was wondering if people could suggest a cleaner (and more efficient) way.
What I am doing is:
import numpy as np

bins = 2  # Use e.g. 100 or larger in reality

# Bin edges from the min/max of the z component
z = df.select(z=pl.col("target").arr.get(2))
z_min = z.min().item()
z_max = z.max().item()
breaks = np.linspace(z_min, z_max, num=bins + 1)

# Smallest bin population (the target number of rows per bin)
counts = (
    df.with_columns(bin=pl.col("target").arr.get(2).cut(breaks))
    .with_columns(counter=pl.int_range(pl.len()).over("bin"))
    .group_by("bin")
    .agg(pl.col("counter").max())
    .filter(pl.col("counter") > 0)  # <- Nasty way of filtering the (-inf, min] bin
    .select(pl.col("counter").min())
).item()

# Keep at most that many rows in every bin
df = (
    df.with_columns(bin=pl.col("target").arr.get(2).cut(breaks))
    .with_columns(counter=pl.int_range(pl.len()).over("bin"))
    .filter(pl.col("counter") <= counts)
    .select("target", "feature")
)
Do people have any suggestions?
I don't think you can avoid those three steps (bin, count, take) for resampling, although depending on your use case you could transform the data instead of dropping rows; see the sketch below.
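For the transform idea, a minimal sketch, assuming you want a roughly uniform distribution along z rather than fewer rows, and given a frame with a z column like the one mocked below (z_uniform is a made-up column name): a rank-based transform maps z to evenly spaced values in [0, 1] without discarding anything.

# Probability-integral-style transform via ranks: no rows are dropped
df = df.with_columns(
    z_uniform=(pl.col('z').rank('ordinal') - 1) / (pl.len() - 1)
)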
You can optimize that code a bit, though:
import polars as pl
import numpy as np

# Some random mocked data
rng = np.random.default_rng()
df = pl.DataFrame({'z': rng.lognormal(size=100_000) - 0.5}).filter(
    pl.col('z').is_between(0.0, 1.0)
)
z = pl.col('z')

# Create the bins using Polars, and only once (99 interior breaks -> 100 bins)
cuts = df.select(pl.linear_space(z.min(), z.max(), 99, closed='none'))['z']
df = df.with_columns(bin=z.cut(cuts))

# Just use len() instead of int_range() + max()
counts = (
    df
    .group_by('bin')
    .len()
    .select(pl.col('len').min())
).item()

# Take the head of each group, or sample
result = (
    df
    .group_by('bin')
    # .head(counts)  # You can just use this instead of .map_groups(...sample(counts)),
    #                # and head() is closer to what you had in the original, but
    #                # taking only the head() may bias the data if the order is not random
    .map_groups(lambda df: df.sample(counts))
    .drop('bin')
)
print(result)
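If the Python lambda in map_groups turns out to be a bottleneck, a possible variation (a sketch, assuming Expr.shuffle is available in your Polars version) keeps everything in the expression engine: shuffle the row indices within each bin and keep the first counts of them.

# Random `counts` rows per bin without a per-group Python callback
# (the seed is arbitrary, just for reproducibility)
result = (
    df
    .filter(pl.int_range(pl.len()).shuffle(seed=0).over('bin') < counts)
    .drop('bin')
)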