Tags: python, python-polars, polars

How to resample a dataset to achieve a uniform distribution


I have a dataset with a schema like:

import numpy as np
import polars as pl

df = pl.DataFrame(
    {
        "target": [
            [1.0, 1.0, 0.0],
            [1.0, 1.0, 0.1],
            [1.0, 1.0, 0.2],
            [1.0, 1.0, 0.8],
            [1.0, 1.0, 0.9],
            [1.0, 1.0, 1.0],
        ],
        "feature": ["a", "b", "c", "d", "e", "f"],
    },
    schema={
        "target": pl.Array(pl.Float32, 3),
        "feature": pl.String,
    },
)

If I make a histogram of the target z-values, it looks like this: [histogram image: original]
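
For reference, the histogram is just over the z component (index 2) of each target; a minimal way to reproduce it, assuming matplotlib is available:

import matplotlib.pyplot as plt

# Extract the z component of each target array and plot its distribution
z_values = df.select(pl.col("target").arr.get(2)).to_series().to_numpy()
plt.hist(z_values, bins=2)  # use e.g. 100 bins in reality
plt.xlabel("z")
plt.show()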

I want to resample the data so it's flat along z.

I managed to do it in a hacky, many-step way (it's also very slow). I was wondering if people could suggest a cleaner (and more efficient) approach?

What I am doing is:

  1. Find the bin edges of said histogram:
bins = 2  # Use e.g. 100 or larger in reality
z = df.select(z=pl.col("target").arr.get(2))
z_min = z.min().item()  # .item() extracts the scalar from the 1x1 frame
z_max = z.max().item()
breaks = np.linspace(z_min, z_max, num=bins + 1)
  2. Find how many counts are in the bin with the fewest counts:
counts = (
    df.with_columns(bin=pl.col("target").arr.get(2).cut(breaks))
    .with_columns(counter=pl.int_range(pl.len()).over("bin"))
    .group_by("bin")
    .agg(pl.col("counter").max())
    .filter(pl.col("counter") > 0)  # <- Nasty way of filtering the (-inf, min] bin
    .select(pl.col("counter").min())
).item()
  3. Keep only that minimum number of rows in each bin:
df = (
    df.with_columns(bin=pl.col("target").arr.get(2).cut(breaks))
    .with_columns(counter=pl.int_range(pl.len()).over("bin"))
    .filter(pl.col("counter") <= counts)
    .select("target", "feature")
)

This gives me a flat distribution: [histogram image: flat]
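
A quick way to sanity-check the flatness without plotting (reusing the breaks from step 1; just a sketch):

check = (
    df.with_columns(bin=pl.col("target").arr.get(2).cut(breaks))
    .group_by("bin")
    .len()
)
print(check)  # every bin should end up with a comparable length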

Do people have any suggestions?


Solution

  • I don't think you can avoid those three steps for resampling (although, depending on your use case, you could try to transform the data instead; see the sketch after the code below).

    You can optimize that code a bit, though:

    import polars as pl
    import numpy as np
    
    # Some random mocked data
    rng = np.random.default_rng()
    df = pl.DataFrame({'z': rng.lognormal(size=100_000) - 0.5}).filter(pl.col('z').is_between(0.0, 1.0))
    
    z = pl.col('z')
    
    # Create the cut points with polars, and only once:
    # 99 interior points (closed='none' excludes the endpoints) -> 100 bins
    cuts = df.select(pl.linear_space(z.min(), z.max(), 99, closed='none'))['z']
    df = df.with_columns(bin=z.cut(cuts))
    
    # Just use len() instead of int_range() + max()
    counts = (
        df
        .group_by("bin")
        .len()
        .select(pl.col("len").min())
    ).item()
    
    # Take the head of each group, or sample it
    result = (
        df
        .group_by('bin')
        # .head(counts) would work here instead of .map_groups(...) and is
        # closer to what you had originally, but taking only the head() may
        # bias the data if the row order is not random
        .map_groups(lambda df: df.sample(counts))
        .drop('bin')
    )
    print(result)
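
    As for transforming instead of resampling: if all you need is a target that is uniform along z, a rank-based transform (the empirical CDF) flattens the distribution without discarding any rows. A minimal sketch, reusing the mocked df from above; whether this is acceptable depends on what consumes the target:

    # Map each z to its normalized rank; the result is (approximately)
    # uniform on [0, 1] by construction, and no rows are dropped.
    df_transformed = df.with_columns(
        z_uniform=(z.rank('ordinal') - 1) / (pl.len() - 1)
    )
    print(df_transformed)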