Tags: python, algorithm, time-complexity

What is a quick way to count the number of pairs in a list where a XOR b is greater than a AND b?


I have an array of numbers, and I want to count all possible pairs for which the XOR of the pair is greater than the AND of the pair.

Example:

4,3,5,2

possible pairs are:

(4,3) -> xor = 7, and = 0
(4,5) -> xor = 1, and = 4
(4,2) -> xor = 6, and = 0
(3,5) -> xor = 6, and = 1
(3,2) -> xor = 1, and = 2
(5,2) -> xor = 7, and = 0

The valid pairs for which xor > and are (4,3), (4,2), (3,5), and (5,2), so the result is 4.

This is my program:

def solve(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans += 1
    return ans

The time complexity is O(n^2), but my array can have up to 10^5 elements and each element is between 1 and 2^30. How can I reduce the time complexity of this program?


Solution

  • This uses (effectively) the same algorithm as yours, so it's still O(n^2), but you can speed up the computation using numpy:

    import numpy as np
    
    def npy(nums):
        # Broadcast nums against a column view of itself to build the full
        # n x n pairwise XOR and AND matrices (nums should be a NumPy array).
        xor_arr = np.bitwise_xor(nums, nums[:, None])
        and_arr = np.bitwise_and(nums, nums[:, None])
    
        # Every pair is counted twice (and the diagonal never satisfies
        # xor > and), so halving gives the number of unordered pairs.
        return (xor_arr > and_arr).sum() // 2
    

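    As a quick sanity check (a usage sketch; it assumes the input is already a NumPy array, since the nums[:, None] broadcasting doesn't work on a plain list):

    nums = np.array([4, 3, 5, 2])
    print(npy(nums))  # 4, matching the worked example above
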
    You could also skip numpy altogether and use numba to JIT-compile your own code before it is run.

    import numba
    
    @numba.njit
    def nba(array):
        n = len(array)
        ans = 0
        for i in range(0, n):
            p1 = array[i]
            for j in range(i, n):
                p2 = array[j]
                if p1 ^ p2 > p1 & p2:
                    ans += 1
        return ans
    

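    A small usage note (a sketch, not part of the original code): the first call pays the JIT-compilation cost, and passing a NumPy array rather than a Python list avoids numba's deprecated reflected-list handling.

    arr = np.array([4, 3, 5, 2])
    nba(arr)         # first call compiles the function for this argument type
    print(nba(arr))  # 4 -- later calls reuse the compiled machine code
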
    Finally, here's my implementation of Dave's algorithm. It relies on the observation that a ^ b > a & b exactly when a and b have different most-significant bits: if the MSBs are in the same position, that bit cancels in the XOR but survives in the AND, so the AND is larger; if they differ, the XOR keeps the higher bit and wins.

    from collections import defaultdict
    
    def new_alg(array):
        msb_num_count = defaultdict(int)
        for num in array:
            msb = len(bin(num)) - 2  # This was faster than right-shifting until zero
            msb_num_count[msb] += 1  # Increment the count of numbers that have this MSB
    
        # Each number forms a valid pair with every number whose MSB differs,
        # i.e. with everything outside its own group; each pair is counted twice.
        cnt = 0
        len_all_groups = len(array)
        for group_len in msb_num_count.values():
            cnt += group_len * (len_all_groups - group_len)
    
        return cnt // 2
    

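    On the example input this puts 4 and 5 in one group (bit length 3) and 3 and 2 in another (bit length 2); each of the four cross-group pairs is counted once from each side, hence the final halving:

    print(new_alg([4, 3, 5, 2]))  # (2*2 + 2*2) // 2 == 4
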
    And here it is as a numba-compatible function. I needed to define a get_msb helper, since numba.njit won't compile the bin()-based trick used above:

    @numba.njit
    def get_msb(num):
        msb = 0
        while num:
            msb += 1
            num = num >> 1
        return msb
    
    @numba.njit
    def new_alg_numba(array):
        msb_num_count = {}
        for num in array:
            msb = get_msb(num)
            if msb not in msb_num_count:
                msb_num_count[msb] = 0
            msb_num_count[msb] += 1
    
        # Same counting step as above: every number pairs with everything
        # outside its own MSB group, and each pair is counted twice.
        cnt = 0
        len_all_groups = len(array)
        for grp_len in msb_num_count.values():
            cnt += grp_len * (len_all_groups - grp_len)
    
        return cnt // 2
    

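    Before timing anything, a quick consistency check is worthwhile (a sketch; solve is the O(n^2) reference implementation from the question):

    import numpy as np

    rng = np.random.default_rng(0)
    sample = rng.integers(1, 2**30, size=500)

    reference = solve(sample)
    assert npy(sample) == reference
    assert nba(sample) == reference
    assert new_alg(sample) == reference
    assert new_alg_numba(sample) == reference
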
    Comparing runtimes, we see that the numba approach is significantly faster than the numpy approach, which is itself faster than looping in pure Python.

    The linear-time algorithm given by Dave is faster than the numpy approach from the start, and it overtakes the numba-compiled loop for inputs of more than roughly 1000 elements. The numba-compiled version of that algorithm is faster still: it outpaces the numba-compiled loopy version at around 100 elements.

    Kelly's excellent implementation of Dave's algorithm is on par with the numba version of my implementation for larger inputs.

    [Benchmark plot: runtime vs. number of elements for each implementation]

    (Your implementation is labelled "loopy". The other legend labels in the plot match the function names in my answer above; Kelly's implementation is labelled "kelly".)
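
    For reference, a minimal timing harness along these lines (a sketch, not the exact benchmark behind the plot; it reuses the functions defined above) could look like:

    import timeit

    import numpy as np

    rng = np.random.default_rng(0)

    # Warm up the JIT-compiled functions so compilation time isn't measured.
    warmup = rng.integers(1, 2**30, size=10)
    nba(warmup)
    new_alg_numba(warmup)

    for n in (100, 1_000, 5_000):
        arr = rng.integers(1, 2**30, size=n)
        lst = arr.tolist()
        for name, func, data in [
            ("loopy", solve, lst),
            ("npy", npy, arr),
            ("nba", nba, arr),
            ("new_alg", new_alg, lst),
            ("new_alg_numba", new_alg_numba, arr),
        ]:
            best = min(timeit.repeat(lambda: func(data), number=1, repeat=3))
            print(f"n={n:>5}  {name:<14} {best:.4f}s")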