I have an array of numbers, and I want to count all pairs for which the xor of the pair is greater than its and.
Example:
4,3,5,2
possible pairs are:
(4,3) -> xor=7, and = 0
(4,5) -> xor=1, and = 4
(4,2) -> xor=6, and = 0
(3,5) -> xor=6, and = 1
(3,2) -> xor=1, and = 2
(5,2) -> xor=7, and = 0
Valid pairs for which xor > and are (4,3), (4,2), (3,5), (5,2) so result is 4.
This is my program:
def solve(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans += 1
    return ans
The time complexity is O(n^2), but my array size is 1 to 10^5 and each element in the array is between 1 and 2^30. How can I reduce the time complexity of this program?
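This uses (effectively) the same algorithm as yours, so it's still O(n^2), but you can speed up the operation using numpy:
np.bitwise_xor performs the bitwise xor operation on two arrays.
np.bitwise_and performs the bitwise and operation on two arrays.
Passing nums and nums[:, None] to these functions broadcasts the result to an n x n matrix that covers every pair.
Comparing the two matrices gives a symmetric boolean matrix in which every pair appears twice. Since a ^ a == 0, the diagonal never counts, so we can simply sum the entire array and divide its result by 2 for the answer.
import numpy as np
def npy(nums):
    # nums must be a numpy array; nums[:, None] is the same data as a column vector
    xor_arr = np.bitwise_xor(nums, nums[:, None])
    and_arr = np.bitwise_and(nums, nums[:, None])
    return (xor_arr > and_arr).sum() // 2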
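One caveat with the broadcast version is that it materializes n x n matrices, which will not fit in memory once n gets close to 10^5. A minimal sketch of a workaround, assuming it is acceptable to process the rows in blocks (npy_chunked and its chunk parameter are names made up for this illustration, not part of numpy):

```python
import numpy as np

def npy_chunked(nums, chunk=1024):
    """Same O(n^2) comparison as npy, but only chunk x n values live at once."""
    nums = np.asarray(nums, dtype=np.int64)
    total = 0
    for start in range(0, len(nums), chunk):
        # Column vector holding the current block of rows
        block = nums[start:start + chunk, None]
        xor_arr = np.bitwise_xor(block, nums)
        and_arr = np.bitwise_and(block, nums)
        total += int((xor_arr > and_arr).sum())
    # Every pair was counted twice, (a, b) and (b, a); the diagonal never counts
    return total // 2

print(npy_chunked([4, 3, 5, 2]))  # 4, matching the example in the question
```

This only bounds the memory use; the running time is still quadratic.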
You could also skip numpy altogether and use numba to JIT-compile your own code before it is run.
import numba
@numba.njit
def nba(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans += 1
    return ans
Finally, here's my implementation of Dave's algorithm:
from collections import defaultdict
def new_alg(array):
    msb_num_count = defaultdict(int)
    for num in array:
        msb = len(bin(num)) - 2  # This was faster than right-shifting until zero
        msb_num_count[msb] += 1  # Increment the count of numbers that have this MSB
    # Now, for each number, the count will be the sum of the numbers in all other groups
    cnt = 0
    len_all_groups = len(array)
    for group_len in msb_num_count.values():
        cnt += group_len * (len_all_groups - group_len)
    return cnt // 2
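The reason grouping by most significant bit works can be checked directly. If two numbers have the same MSB, that bit is set in their and but cleared in their xor, so the and is larger. If the MSBs differ, the higher one survives only in the xor, so the xor is larger. A small sketch that verifies this against a brute-force count (brute_force, same_msb and property_holds are helper names introduced here for illustration only):

```python
from itertools import combinations

def brute_force(nums):
    """Count pairs with xor > and by checking every pair (O(n^2))."""
    return sum(1 for a, b in combinations(nums, 2) if (a ^ b) > (a & b))

def same_msb(a, b):
    """True when a and b have their most significant bit in the same position."""
    return a.bit_length() == b.bit_length()

def property_holds(limit):
    """Check that xor > and exactly when the MSBs differ, for all pairs below limit."""
    return all(
        ((a ^ b) > (a & b)) == (not same_msb(a, b))
        for a, b in combinations(range(1, limit), 2)
    )

print(brute_force([4, 3, 5, 2]))  # 4, matching the example in the question
print(property_holds(256))        # True
```

So the answer is the number of pairs whose elements fall into different MSB groups, which is what the group-size product above counts.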
And here it is as a numba-compatible function. I needed to define a get_msb helper, since numba.njit won't handle every builtin Python function.
@numba.njit
def get_msb(num):
    # Count how many right shifts it takes to reach zero, i.e. the bit length
    msb = 0
    while num:
        msb += 1
        num = num >> 1
    return msb

@numba.njit
def new_alg_numba(array):
    msb_num_count = {}
    for num in array:
        msb = get_msb(num)
        if msb not in msb_num_count:
            msb_num_count[msb] = 0
        msb_num_count[msb] += 1
    # Now, for each number, the count will be the sum of the numbers in all other groups
    cnt = 0
    len_all_groups = len(array)
    for grp_len in msb_num_count.values():
        cnt += grp_len * (len_all_groups - grp_len)
    return cnt // 2
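If numba is not available, the same counting can also be vectorized with numpy alone. This is only a sketch, assuming all inputs are positive integers below 2^53 so that they convert to float64 exactly; msb_count_numpy is a name invented for this example and I have not benchmarked it against the functions above. It relies on np.frexp, whose returned exponent equals the bit length for such inputs, and on np.bincount to tally the groups:

```python
import numpy as np

def msb_count_numpy(nums):
    """Count pairs whose most significant bits differ, without a Python loop."""
    nums = np.asarray(nums)
    # frexp writes x as m * 2**e with 0.5 <= m < 1, so e is the bit length of x
    _, bit_lengths = np.frexp(nums)
    # group_sizes[k] is how many numbers have bit length k
    group_sizes = np.bincount(bit_lengths)
    n = nums.size
    # Each cross-group pair is counted once from each side, hence the // 2
    return int((group_sizes * (n - group_sizes)).sum() // 2)

print(msb_count_numpy([4, 3, 5, 2]))  # 4, matching the example in the question
```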
Comparing runtimes, we see that the numba approach is significantly faster than the numpy approach, which is itself faster than looping in Python.
The linear-time algorithm given by Dave is faster than the numpy approach to begin with, and it becomes faster than the numba-compiled loop for inputs larger than ~1000 elements. The numba-compiled version of this approach is even faster: it outpaces the numba-compiled loopy at ~100 elements.
Kelly's excellent implementation of Dave's algorithm is on par with the numba version of my implementation for larger inputs.
(Your implementation is labelled "loopy". Other legend labels in the plot are the same as function names in my answer above. Kelly's implementation is labelled "kelly".)