With the help of Claude Sonnet 4, I cooked up this function, which I hope does what I asked it to do.
import polars as pl

def has_duplicates_early_exit(df: pl.LazyFrame, subset: list[str]) -> bool:
    """Can exit early when first duplicate is found"""
    return df.select(
        pl.struct(subset).is_duplicated().any()
    ).collect().item()
Is this the most efficient way to do it?
You can't make it short-circuit when the duplicates come early unless you write a plugin, or use map_batches and implement the short circuit manually (which will be slower, except in extreme cases, because it means using a Python loop). You can prove this to yourself by making two big dfs, one where the duplicates are very far apart and another where they're right at the beginning, and comparing the timings:
early = pl.select(a=pl.concat_list(0, pl.int_ranges(0, 100_000_000))).explode('a').lazy()  # the two 0s are adjacent at the start
late = pl.select(a=pl.concat_list(pl.int_ranges(0, 100_000_000), 0)).explode('a').lazy()   # the matching 0s sit at opposite ends
%%timeit
has_duplicates_early_exit(early, ['a'])
# 5.12 s ± 290 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
has_duplicates_early_exit(late, ['a'])
# 5.14 s ± 157 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
These are essentially the same. The reason early isn't faster is that it doesn't short-circuit. It can't short-circuit because is_duplicated doesn't know you've chained .any(), so even when it finds an early true it doesn't know it can stop.
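To make that concrete, here's a tiny frame (the data is made up for illustration): is_duplicated produces a full per-row boolean column, and .any() only reduces it afterwards.

small = pl.LazyFrame({'a': [0, 0, 1, 2]})
# is_duplicated materializes one boolean per row...
small.select(pl.struct('a').is_duplicated()).collect()
# shape: (4, 1) with values [true, true, false, false]
# ...and .any() just folds that full column down to a single value
small.select(pl.struct('a').is_duplicated().any()).collect().item()
# True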
You could either write a Rust plugin that is designed to short-circuit, or use this Python implementation to see short-circuiting in action (don't actually use this Python implementation in real life):
import time

def early_is_dup(s: pl.Series) -> pl.Series:
    start = time.time()
    for i, left_val in enumerate(s):
        for j, right_val in enumerate(s[i + 1:]):
            # bail out so the demo doesn't run forever
            if time.time() - start > 10:
                print(f"took too long, at {i=} {j=}")
                return pl.Series([None], dtype=pl.Boolean)
            if left_val == right_val:
                # first duplicate found: stop immediately
                return pl.Series([True])
    return pl.Series([False])
early.select(pl.struct('a').map_batches(early_is_dup)).collect()
# this takes 0.0sec
late.select(pl.struct('a').map_batches(early_is_dup)).collect()
# took too long, at i=0 j=27725000
In the extreme case of the first two elements being the same, this is instant. If the duplicates are far apart then it's going to take a prohibitive amount of time.
If you're only dealing with a single-column subset then this method is very fast for both early and late:
early.select(pl.col('a').unique().len()!=pl.col('a').len()).collect()
late.select(pl.col('a').unique().len()!=pl.col('a').len()).collect()
# each takes about 0.7sec
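(Incidentally, I'd expect n_unique to be an equivalent spelling of unique().len() here, presumably with similar performance; I haven't timed this variant:)

# assumed-equivalent spelling of the single-column check (not timed here)
early.select(pl.col('a').n_unique() != pl.col('a').len()).collect()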
However, if I do the same thing through a struct:
early.select(pl.struct('a').unique().len()!=pl.struct('a').len()).collect()
late.select(pl.struct('a').unique().len()!=pl.struct('a').len()).collect()
# each takes about 1m8s
Here's a trick for multiple columns which beats the is_duplicated approach but is slower than the single-column method above:
early.select(pl.struct('a').sort().rle_id().max()+1!=pl.struct('a').len()).collect()
late.select(pl.struct('a').sort().rle_id().max()+1!=pl.struct('a').len()).collect()
# each takes about 3s
It works by first sorting the data, then getting its rle_id. If everything is unique then there are no runs of data, and the max rle_id will be 1 less than the len of the column (because the ids start at 0).
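To see the mechanics on a toy example (the values are made up):

toy = pl.LazyFrame({'a': [3, 1, 3, 2]})
toy.select(pl.col('a').sort()).collect().to_series().to_list()
# [1, 2, 3, 3]
toy.select(pl.col('a').sort().rle_id()).collect().to_series().to_list()
# [0, 1, 2, 2]  <- the duplicated 3s share a run id
toy.select(pl.col('a').sort().rle_id().max() + 1 != pl.col('a').len()).collect().item()
# True: 3 runs but 4 rows, so something must be duplicated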
Putting it all together, you could make your function:
def has_duplicates_early_exit(df: pl.LazyFrame, subset: list[str]) -> bool:
    """Can't exit early but is faster anyway"""
    if len(subset) == 1:
        # single column: compare the unique count to the total count
        return (
            df.select(pl.col(subset[0]).unique().len() != pl.col(subset[0]).len())
            .collect()
            .item()
        )
    else:
        # multiple columns: sort the struct and compare the run count to the row count
        return (
            df.select(
                pl.struct(subset).sort().rle_id().max() + 1 != pl.struct(subset).len()
            )
            .collect()
            .item()
        )
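For example, on a hypothetical two-column frame (data made up to exercise both branches):

df = pl.LazyFrame({'x': [1, 1, 2], 'y': ['a', 'b', 'b']})
has_duplicates_early_exit(df, ['x'])       # True: 1 appears twice in x
has_duplicates_early_exit(df, ['x', 'y'])  # False: every (x, y) pair is unique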