python join optimization mathematical-optimization self-join

How to identify cases where both elements of a pair are greater than others' respective elements in the set?


I have a list of pairs, each holding two numerical values. I want to find the subset containing only those pairs that are not matched or exceeded in both elements by another pair (let's say "eclipsed" by another).

For example, the pair (1,2) is eclipsed by (4,5) because both elements are less than the respective elements in the other pair.

Also, (1,2) is considered eclipsed by (1,3) because the first element is equal to the other's and the second element is less than the other's.

However, the pair (2,10) is not eclipsed by (9,9) because only one of its elements is exceeded by the other's.

Cases where the pairs are identical should be reduced to just one (duplicates removed).

Ultimately, I am looking to reduce the list of pairs to a subset where only pairs that were not eclipsed by any others remain.
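To pin the rule down, here is a minimal sketch of the pairwise check in plain Python. The name `eclipses` is only illustrative, and it assumes identical pairs do not eclipse each other (they are deduplicated instead).

```python
def eclipses(other, pair):
    """Return True if `other` eclipses `pair`.

    Assumption: `other` must be at least as large in both elements and
    must differ from `pair`; identical pairs are treated as duplicates.
    """
    return other != pair and other[0] >= pair[0] and other[1] >= pair[1]


print(eclipses((4, 5), (1, 2)))   # True: both elements exceeded
print(eclipses((1, 3), (1, 2)))   # True: first equal, second exceeded
print(eclipses((9, 9), (2, 10)))  # False: only one element exceeded
print(eclipses((2, 2), (2, 2)))   # False: identical pair, a duplicate
```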

For example, take the following list:

(1,2)
(1,5)
(2,2)
(1,2)
(2,2)
(9,1)
(1,1)

This should be reduced to the following:

(1,5)
(2,2)
(9,1)

My initial implementation of this in Python, using polars, was the following:

import polars as pl

pairs_list = [
    (1,2),
    (1,5),
    (2,2),
    (1,2),
    (2,2),
    (9,1),
    (1,1),
]

# tabulate pair elements as 'a' and 'b'
pairs = pl.DataFrame(
    data=pairs_list,
    schema={'a': pl.UInt32, 'b': pl.UInt32},
    orient='row',
)

# eliminate any duplicate pairs
unique_pairs = pairs.unique()

# self join so every pair can be compared (except against itself)
comparison_inputs = (
    unique_pairs
    .join(
        unique_pairs,
        how='cross',
        suffix='_comp',
    )
    .filter(
        pl.any_horizontal(
            pl.col('a') != pl.col('a_comp'),
            pl.col('b') != pl.col('b_comp'),
        )
    )
)

# flag pairs that were eclipsed by others
comparison_results = (
    comparison_inputs
    .with_columns(
        pl.all_horizontal(
            pl.col('a') <= pl.col('a_comp'),
            pl.col('b') <= pl.col('b_comp'),
        )
        .alias('is_eclipsed')
    )
)

# remove pairs that were eclipsed by at least one other
principal_pairs = (
    comparison_results
    .group_by('a', 'b')
    .agg(pl.col('is_eclipsed').any())
    .filter(is_eclipsed=False)
    .select('a', 'b')
)

While this does appear to work, it is computationally infeasible for large datasets due to the sheer size of the self-joined table.

I have considered shrinking the comparison_inputs table by removing redundant reversed comparisons: pair X vs. pair Y and pair Y vs. pair X don't both need to be in the table, as they currently are. But that requires an additional condition in each comparison to report which pair was eclipsed, and it only cuts the table in half, which isn't that significant.

I have found I can reduce the needed comparisons substantially by doing a window function filter that filters to only the max b for each a and vice versa before doing the self joining step. In other words:

unique_pairs = (
    pairs
    .unique()
    # keep only the largest a for each b ...
    .filter(a=pl.col('a').last().over('b', order_by='a'))
    # ... and then only the largest b for each remaining a
    .filter(b=pl.col('b').last().over('a', order_by='b'))
)

But of course this only does so much and depends on the cardinality of a and b. I still need to self-join and compare after this to get a result.
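As a rough illustration of why this pre-filter is not enough on its own, here is the same idea sketched in plain Python rather than polars. The function name `prefilter_pairs` is made up for this example.

```python
def prefilter_pairs(pairs):
    """Keep the largest a for each b, then the largest b for each remaining a."""
    unique = set(pairs)

    # Step 1: for every distinct b, remember the largest a paired with it.
    best_a = {}
    for a, b in unique:
        best_a[b] = max(a, best_a.get(b, a))
    step1 = {(a, b) for a, b in unique if a == best_a[b]}

    # Step 2: for every distinct a that is left, remember the largest b.
    best_b = {}
    for a, b in step1:
        best_b[a] = max(b, best_b.get(a, b))
    return sorted((a, b) for a, b in step1 if b == best_b[a])


# When every a and every b is distinct, nothing is removed, although
# (1, 1) and (2, 2) are both eclipsed by (3, 3).
print(prefilter_pairs([(1, 1), (2, 2), (3, 3)]))  # [(1, 1), (2, 2), (3, 3)]
```

So with high-cardinality columns the pre-filter can leave the input untouched, and the pairwise comparison is still needed afterwards.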

I am curious if there is already some algorithm established for calculating this and whether anyone has ideas for a more efficient method. Interested to learn more anyway. Thanks in advance.


Solution

  • From my perspective, we can do the following. First, remove duplicates and sort the pairs by first element in descending order, breaking ties on the first element by the second element in descending order:

    unique_pairs = sorted(set(pairs), reverse=True)
    

    Then scan the pairs in that order and apply one condition to each: if b is greater than the maximum second element seen so far (all earlier pairs have a first element at least as large), this pair cannot be eclipsed, so keep it. Otherwise some earlier pair eclipses it.

    from typing import List, Tuple

    def find_non_eclipsed_pairs(pairs: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
        if not pairs:
            return []

        # Deduplicate, then sort descending by first element (ties: second element).
        unique_pairs = sorted(set(pairs), reverse=True)

        result = []
        max_second = None  # largest second element among the pairs visited so far

        for pair in unique_pairs:
            # Every pair visited earlier has a first element >= pair[0], so this
            # pair survives only if its second element beats all of theirs.
            if max_second is None or pair[1] > max_second:
                result.append(pair)
                max_second = pair[1]

        return sorted(result)
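To see the scan at work, the small sketch below records the running maximum before each pair is examined, using the example list from the question. `trace_scan` is a hypothetical helper written only for this trace.

```python
def trace_scan(pairs):
    """Return (pair, max_b_before, kept) for each pair in scan order."""
    rows = []
    max_b = None  # largest second element seen so far
    for a, b in sorted(set(pairs), reverse=True):
        kept = max_b is None or b > max_b
        rows.append(((a, b), max_b, kept))
        if kept:
            max_b = b
    return rows


sample = [(1, 2), (1, 5), (2, 2), (1, 2), (2, 2), (9, 1), (1, 1)]
for pair, max_before, kept in trace_scan(sample):
    print(pair, max_before, "keep" if kept else "eclipsed")
# (9, 1) None keep
# (2, 2) 1 keep
# (1, 5) 2 keep
# (1, 2) 5 eclipsed
# (1, 1) 5 eclipsed
```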
    

    Testing

    def test_pareto_pairs():
        test_cases = [
            (
                [(1,2), (1,5), (2,2), (1,2), (2,2), (9,1), (1,1)],
                [(1,5), (2,2), (9,1)]
            ),
            (
                [],
                []
            ),
            (
                [(1,1)],
                [(1,1)]
            ),
            (
                [(1,1), (2,2), (3,3), (4,4)],
                [(4,4)]
            ),
            (
                [(1,5), (5,1)],
                [(1,5), (5,1)]
            ),
            (
                [(1,1), (1,2), (2,1), (2,2), (3,1), (1,3)],
                [(1,3), (2,2), (3,1)]
            )
        ]
        
        for i, (input_pairs, expected) in enumerate(test_cases, 1):
            result = find_non_eclipsed_pairs(input_pairs)
            assert result == sorted(expected), f"Test case {i} failed: expected {expected}, got {result}"
            print(f"Test case {i} passed")
    
    if __name__ == "__main__":
        test_pareto_pairs()
        
        pairs_list = [
            (1,2),
            (1,5),
            (2,2),
            (1,2),
            (2,2),
            (9,1),
            (1,1),
        ]
        
        result = find_non_eclipsed_pairs(pairs_list)
        print("\nOriginal pairs:", pairs_list)
        print("Non-eclipsed pairs:", result)
    

    Which results in:

    Test case 1 passed
    Test case 2 passed
    Test case 3 passed
    Test case 4 passed
    Test case 5 passed
    Test case 6 passed
    
    Original pairs: [(1, 2), (1, 5), (2, 2), (1, 2), (2, 2), (9, 1), (1, 1)]
    Non-eclipsed pairs: [(1, 5), (2, 2), (9, 1)]
    

    Time complexity is O(n log n), dominated by the sort; space complexity is O(n).
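This kind of result is usually called the Pareto front (or the maxima of a point set). One way to gain some confidence in the sort-and-scan approach is to compare it with the quadratic definition on random inputs. The sketch below is self-contained; `scan`, `brute_force` and `cross_check` are illustrative names, and the trial sizes are arbitrary.

```python
import random


def scan(pairs):
    """Sort-and-scan version: O(n log n)."""
    kept, max_b = [], None
    for a, b in sorted(set(pairs), reverse=True):
        if max_b is None or b > max_b:
            kept.append((a, b))
            max_b = b
    return sorted(kept)


def brute_force(pairs):
    """Direct application of the definition: O(n^2)."""
    unique = set(pairs)
    return sorted(
        p for p in unique
        if not any(q != p and q[0] >= p[0] and q[1] >= p[1] for q in unique)
    )


def cross_check(trials=200, size=30, upper=10, seed=0):
    """Return the first input on which the two versions disagree, else None."""
    rng = random.Random(seed)
    for _ in range(trials):
        pairs = [(rng.randint(0, upper), rng.randint(0, upper)) for _ in range(size)]
        if scan(pairs) != brute_force(pairs):
            return pairs
    return None


print(cross_check())  # prints None when no mismatch is found
```

A small value range (`upper=10`) is chosen on purpose so that ties and duplicates occur often, since those are the cases most likely to go wrong.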

    Edit: Thanks to @no comment for suggesting sort with reverse=True:

    unique_pairs = sorted(set(pairs), reverse=True)