python python-polars

How to speed up repeatedly filtering a DataFrame and creating columns from the filter values?


The task is to filter a polars DataFrame df with many combinations of conditions, add columns holding the bound values used by each combination, and then concatenate the resulting sub-DataFrames across all combinations.

import polars as pl
from itertools import product

# Sample frame: five integer columns with ten rows each.
df = pl.DataFrame(
    {
        "A": [31, 32, 73, 24, 15, 26, 57, 98, 79, 10],
        "B": [11, 22, 53, 44, 53, 16, 27, 38, 49, 10],
        "C": [41, 12, 23, 44, 25, 46, 27, 48, 29, 10],
        "D": [71, 52, 13, 34, 53, 36, 27, 48, 39, 10],
        "E": [81, 82, 63, 24, 15, 56, 47, 68, 49, 10],
    }
)

# Candidate [lower, upper] bounds tested against columns A, B, C and D.
a12_l = [[13, 16], [12, 72], [18, 22]]
b12_l = [[11, 13], [14, 55], [22, 55]]
c12_l = [[23, 76], [13, 65], [23, 56]]
d12_l = [[21, 42], [18, 25], [25, 35]]


def _rows_within(a_rng, b_rng, c_rng, d_rng):
    """Return the rows of ``df`` inside all four ranges, tagged with the bounds.

    Each argument is a ``[lower, upper]`` pair for columns A, B, C and D
    respectively. The result carries eight extra columns (a1, a2, ..., d2)
    holding the bounds that were applied.
    """
    (a1, a2), (b1, b2), (c1, c2), (d1, d2) = a_rng, b_rng, c_rng, d_rng
    in_all_ranges = (
        pl.col("A").is_between(a1, a2)
        & pl.col("B").is_between(b1, b2)
        & pl.col("C").is_between(c1, c2)
        & pl.col("D").is_between(d1, d2)
    )
    bounds = {"a1": a1, "a2": a2, "b1": b1, "b2": b2,
              "c1": c1, "c2": c2, "d1": d1, "d2": d2}
    return df.filter(in_all_ranges).with_columns(**bounds)


# One filtered sub-frame per combination of ranges, stacked vertically.
pl.concat([_rows_within(*combo) for combo in product(a12_l, b12_l, c12_l, d12_l)])

The real df has millions of rows and there are up to 200,000 filter combinations, so the whole process is quite slow.


Solution

  • You can evaluate every comparison once and store the results in a DataFrame, then use all_horizontal() to compute the logical AND for each combination:

    import polars as pl
    from itertools import product, chain
    from collections import defaultdict
    import numpy as np
    
    # Demo data: columns A-D, each holding n random integers drawn from [0, 100).
    n = 50
    df = pl.DataFrame(np.random.randint(0, 100, size=(4, n)), schema=['A', 'B', 'C', 'D'])
    
    # Candidate [lower, upper] bounds, one list per column of df.
    a12_l = [[13,16], [12,72], [18,22], [80, 90]]
    b12_l = [[11,13], [14,55], [22,55], [70, 80]]
    c12_l = [[23,76], [13,65], [23,56], [60, 80]]
    d12_l = [[21,42], [18,25], [25,35], [50, 60]]
    
    # conditions[k] holds the candidate ranges for df.columns[k].
    conditions = [a12_l, b12_l, c12_l, d12_l]
    
    # Evaluate each individual range test exactly once. Column f'{c}{i}' is the
    # boolean mask "column c lies within its i-th range".
    dfs = pl.concat([
        df.select(pl.col(c).is_between(vmin, vmax).alias(f'{c}{i}') for i, (vmin, vmax) in enumerate(ranges))
            for c, ranges in zip(df.columns, conditions)
    ], how='horizontal')
    
    # One expression per combination of range indices: AND the precomputed masks,
    # take the positions of the matching rows and pack them into one list value.
    exprs_index = [
        pl.all_horizontal([f"{c}{i}" for c, i in zip(df.columns, idx)]).arg_true().implode()
        for idx in product(*[range(len(c)) for c in conditions])
    ]
    
    # The bounds of every combination, one row each. product() walks the
    # combinations in the same order as for exprs_index, so row k here belongs to
    # the k-th expression above.
    df_ranges = pl.DataFrame(
        [list(chain.from_iterable(cond)) for cond in product(*conditions)],
        orient='row',
        # Pass a real list of names (a1, a2, b1, b2, ...) rather than a one-shot
        # iterator: `schema` is documented as a sequence or mapping.
        schema=[f'{c.lower()}{k}' for c in df.columns for k in (1, 2)],
    )
    
    # Attach the list of matching row positions to each combination, drop the
    # combinations that matched nothing, then emit one row per (position, combination).
    row_index = pl.concat(
        [dfs.select(pl.concat(exprs_index).alias('index')), df_ranges], how='horizontal'
    ).filter(pl.col('index').list.len() > 0).explode('index')
    
    # Look up the original row values for every matched position; the right join
    # keeps one output row per entry of row_index.
    df2 = df.with_row_index().join(row_index, how='right', on='index').drop('index')