python, python-polars, polars

What is the most efficient way to check if a Polars LazyFrame has duplicates?


With the help of Claude Sonnet 4, I cooked up this function, which I hope does what I asked it to do.

import polars as pl

def has_duplicates_early_exit(df: pl.LazyFrame, subset: list[str]) -> bool:
    """Can exit early when first duplicate is found"""
    return df.select(
        pl.struct(subset).is_duplicated().any()
    ).collect().item()

Is this the most efficient you can do?


Solution

  • You can't make it short-circuit when the duplicates are early unless you write a plugin or use map_batches and implement the short circuit manually (which will be slower, except in extreme examples, because it means using a Python loop).

    You can prove this to yourself by making two big dfs, one where the duplicates are very far apart and another where they're at the beginning, and comparing the timings of the two.

    # duplicate value (0) at the very start
    early = pl.select(a=pl.concat_list(0, pl.int_ranges(0, 100_000_000))).explode('a').lazy()
    # duplicate value (0) at the very end
    late = pl.select(a=pl.concat_list(pl.int_ranges(0, 100_000_000), 0)).explode('a').lazy()
    
    %%timeit
    has_duplicates_early_exit(early, 'a')
    # 5.12 s ± 290 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    %%timeit
    has_duplicates_early_exit(late, 'a')
    # 5.14 s ± 157 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    These are essentially the same. The reason early isn't faster is that it doesn't short-circuit. It can't short-circuit because is_duplicated doesn't know you've chained .any(), so even when it finds an early True it doesn't know it can stop.

    You could either make a Rust plugin that is designed to short-circuit, or use this Python implementation to see short-circuiting in action (don't actually use this Python implementation in real life):

    import time

    def early_is_dup(s: pl.Series) -> pl.Series:
        """Brute-force pairwise comparison that returns as soon as a duplicate is found."""
        start = time.time()
        for i, left_val in enumerate(s):
            for j, right_val in enumerate(s[i+1:]):
                # bail out so the demo doesn't run forever when duplicates are far apart
                if time.time() - start > 10:
                    print(f"took too long, at {i=} {j=}")
                    return pl.Series([None], dtype=pl.Boolean)
                if left_val == right_val:
                    return pl.Series([True])
        return pl.Series([False])
    
    
    early.select(pl.struct('a').map_batches(early_is_dup)).collect()
    # this takes 0.0sec
    
    late.select(pl.struct('a').map_batches(early_is_dup)).collect()
    # took too long, at i=0 j=27725000
    

    In the extreme case of the first two elements being the same, this is instant. If the duplicates are far apart then it's going to take a prohibitive amount of time.

    Alternatives

    If you're only dealing with a single-column subset, this method is very fast for both early and late:

    early.select(pl.col('a').unique().len()!=pl.col('a').len()).collect()
    late.select(pl.col('a').unique().len()!=pl.col('a').len()).collect()
    
    # each take about 0.7sec
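
    As an aside (my addition, not part of the original benchmarks), the same single-column check can also be spelled with n_unique, which I'd expect to perform similarly:

    early.select(pl.col('a').n_unique() != pl.col('a').len()).collect()
    late.select(pl.col('a').n_unique() != pl.col('a').len()).collect()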
    

    However, if I do the same with a struct:

    early.select(pl.struct('a').unique().len()!=pl.struct('a').len()).collect()
    late.select(pl.struct('a').unique().len()!=pl.struct('a').len()).collect()
    # each take about 1m8s
    

    Here's a trick for multiple columns which beats the is_duplicated approach but is slower than the one-column method above:

    early.select(pl.struct('a').sort().rle_id().max()+1!=pl.struct('a').len()).collect()
    late.select(pl.struct('a').sort().rle_id().max()+1!=pl.struct('a').len()).collect()
    # each take about 3s
    

    It works by first sorting the data and then getting its rle_id. If everything is unique there are no runs of repeated values, so the max rle_id will be 1 less than the length of the column (because the ids start at 0); any duplicate forms a run and makes the max smaller.
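
    To make the mechanics concrete, here's a tiny eager sketch on a toy Series (values chosen purely for illustration, not from the benchmark frames above):

    s = pl.Series('a', [3, 1, 3, 2]).sort()  # sorted -> [1, 2, 3, 3]
    s.rle_id()                               # run ids -> [0, 1, 2, 2], max is 2
    s.rle_id().max() + 1 != s.len()          # 3 != 4 -> True, so there is a duplicate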

    Putting it all together, you could make your function:

    def has_duplicates_early_exit(df: pl.LazyFrame, subset: list[str]) -> bool:
        """Can't exit early but is faster anyway"""
        if len(subset) == 1:
            return (
                df.select(pl.col(subset[0]).unique().len() != pl.col(subset[0]).len())
                .collect()
                .item()
            )
        else:
            return (
                df.select(
                    pl.struct(subset).sort().rle_id().max() + 1 != pl.struct(subset).len()
                )
                .collect()
                .item()
            )
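
    For completeness, a quick usage sketch against the early and late frames built above (it just wraps the expressions already timed):

    has_duplicates_early_exit(early, ['a'])  # True: the value 0 appears twice
    has_duplicates_early_exit(late, ['a'])   # True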