The task is to filter a Polars DataFrame `df` with many combinations of range conditions. For each combination, add columns holding that combination's bound values, then concatenate the resulting sub-DataFrames across all combinations.
import polars as pl
from itertools import product
# Toy input: five integer columns; only A-D take part in the filtering.
df = pl.DataFrame(
    {
        "A": [31, 32, 73, 24, 15, 26, 57, 98, 79, 10],
        "B": [11, 22, 53, 44, 53, 16, 27, 38, 49, 10],
        "C": [41, 12, 23, 44, 25, 46, 27, 48, 29, 10],
        "D": [71, 52, 13, 34, 53, 36, 27, 48, 39, 10],
        "E": [81, 82, 63, 24, 15, 56, 47, 68, 49, 10],
    }
)

# Candidate [lower, upper] bounds for columns A, B, C and D respectively,
# passed to is_between (which includes both ends by default).
a12_l = [[13, 16], [12, 72], [18, 22]]
b12_l = [[11, 13], [14, 55], [22, 55]]
c12_l = [[23, 76], [13, 65], [23, 56]]
d12_l = [[21, 42], [18, 25], [25, 35]]

# Naive approach: one full filter pass over df per combination of ranges
# (len(a12_l) * len(b12_l) * len(c12_l) * len(d12_l) passes).  Each matching
# sub-frame is tagged with the bounds that selected it, then all are stacked.
pl.concat(
    [
        df.filter(
            pl.col('A').is_between(lo_a, hi_a)
            & pl.col('B').is_between(lo_b, hi_b)
            & pl.col('C').is_between(lo_c, hi_c)
            & pl.col('D').is_between(lo_d, hi_d)
        ).with_columns(
            a1=lo_a, a2=hi_a,
            b1=lo_b, b2=hi_b,
            c1=lo_c, c2=hi_c,
            d1=lo_d, d2=hi_d,
        )
        for (lo_a, hi_a), (lo_b, hi_b), (lo_c, hi_c), (lo_d, hi_d)
        in product(a12_l, b12_l, c12_l, d12_l)
    ]
)
`df` has millions of rows and there can be up to 200,000 filter combinations, so the whole process is quite slow.
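You can evaluate every individual comparison (one column against one range) only once and store the boolean results in a DataFrame. With 4 ranges per column over 4 columns, that is just 16 boolean columns instead of one comparison per column per combination. Then use `pl.all_horizontal()` to compute the logical AND for each combination: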
import polars as pl
from itertools import product, chain
from collections import defaultdict
import numpy as np
# Demo data: n rows with one column per filtered variable.  The array is built
# row-major and the orientation is stated explicitly, instead of relying on
# Polars to infer how a (4, n) array maps onto a 4-name schema.
n = 50
# Columns the range conditions apply to, in the same order as `conditions`.
# Kept separate from df.columns so that extra, unfiltered columns in df (such
# as 'E' in the question's data) do not break the zips and schema below.
filter_cols = ['A', 'B', 'C', 'D']
df = pl.DataFrame(
    np.random.randint(0, 100, size=(n, len(filter_cols))),
    schema=filter_cols,
    orient='row',
)

a12_l = [[13, 16], [12, 72], [18, 22], [80, 90]]
b12_l = [[11, 13], [14, 55], [22, 55], [70, 80]]
c12_l = [[23, 76], [13, 65], [23, 56], [60, 80]]
d12_l = [[21, 42], [18, 25], [25, 35], [50, 60]]
# One list of [lower, upper] ranges per entry of filter_cols.
conditions = [a12_l, b12_l, c12_l, d12_l]
if len(conditions) != len(filter_cols):
    raise ValueError('conditions must have exactly one list of ranges per column in filter_cols')

# Step 1: evaluate each (column, range) comparison exactly once.  Column A with
# ranges a12_l yields boolean columns A0, A1, ...; likewise for B, C and D.
dfs = pl.concat(
    [
        df.select(
            pl.col(col).is_between(vmin, vmax).alias(f'{col}{i}')
            for i, (vmin, vmax) in enumerate(ranges)
        )
        for col, ranges in zip(filter_cols, conditions)
    ],
    how='horizontal',
)

# Step 2: for every combination (one range index per column), AND the matching
# boolean columns and collect the positions of the rows where all are true into
# a single list value.  Combinations are enumerated in product() order, the same
# order df_ranges uses below, so expression k lines up with row k of df_ranges.
exprs_index = [
    pl.all_horizontal([f'{col}{i}' for col, i in zip(filter_cols, idx)])
    .arg_true()
    .implode()
    for idx in product(*(range(len(ranges)) for ranges in conditions))
]

# Step 3: one row per combination holding its bounds as a1, a2, b1, b2, ...
# The schema is a real list (not a one-shot iterator).
range_cols = [f'{col.lower()}{k}' for col in filter_cols for k in (1, 2)]
df_ranges = pl.DataFrame(
    [list(chain.from_iterable(combo)) for combo in product(*conditions)],
    orient='row',
    schema=range_cols,
)

# Step 4: pair each combination's matching row positions with its bounds, drop
# combinations that matched nothing, and emit one row per (combination, match).
row_index = pl.concat(
    [dfs.select(pl.concat(exprs_index).alias('index')), df_ranges],
    how='horizontal',
).filter(pl.col('index').list.len() > 0).explode('index')

# Step 5: pull the original rows back in by row position and attach the bounds.
# NOTE(review): whether the right join preserves row_index's order depends on
# the Polars version (newer releases expose `maintain_order`); sort explicitly
# if the output must be grouped by combination like the pl.concat version.
df2 = df.with_row_index().join(row_index, how='right', on='index').drop('index')