I have two dataframes: a (~600M rows) and b (~2M rows). What is the best approach for joining b onto a using one equality condition and two inequality conditions on the respective columns?
I have explored the following paths so far:
I'm running out of ideas. What would be a more efficient way to implement this?
Thank you
```python
import numba as nb
import numpy as np
import polars as pl
import time

@nb.njit(nb.int32[:](nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:]), parallel=True)
def join_multi_ineq(a_1, a_2, a_3, b_1, b_2, b_3, b_4):
    output = np.zeros(len(a_1), dtype=np.int32)
    for i in nb.prange(len(a_1)):
        # Scan b backwards so that the last matching row of b wins.
        for j in range(len(b_1) - 1, -1, -1):
            if a_1[i] == b_1[j]:
                if a_2[i] >= b_2[j]:
                    if a_3[i] >= b_3[j]:
                        output[i] = b_4[j]
                        break
    return output

length_a = 5_000_000
length_b = 2_000_000

start_time = time.time()
output = join_multi_ineq(a_1=np.random.randint(1, 1_000, length_a, dtype=np.int32),
                         a_2=np.random.randint(1, 1_000, length_a, dtype=np.int32),
                         a_3=np.random.randint(1, 1_000, length_a, dtype=np.int32),
                         b_1=np.random.randint(1, 1_000, length_b, dtype=np.int32),
                         b_2=np.random.randint(1, 1_000, length_b, dtype=np.int32),
                         b_3=np.random.randint(1, 1_000, length_b, dtype=np.int32),
                         b_4=np.random.randint(1, 1_000, length_b, dtype=np.int32))
print(f"Duration: {(time.time() - start_time):.2f} seconds")
```
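To pin down what the kernel above computes, here is the same join as a tiny brute-force NumPy function. The name `join_reference` is hypothetical and the function is only meant for checking faster variants on small inputs: for each row of a, it returns b_4 of the last row of b that satisfies all three conditions, or 0 when there is none. The 0 sentinel mirrors `np.zeros` in the kernel and is only unambiguous because b_4 is drawn from 1..999 in the benchmark.

```python
import numpy as np

def join_reference(a_1, a_2, a_3, b_1, b_2, b_3, b_4):
    """Brute-force reference (hypothetical helper) for small inputs only."""
    output = np.zeros(len(a_1), dtype=np.int32)  # 0 means "no matching row of b"
    for i in range(len(a_1)):
        # One equality condition and two inequality conditions, vectorized over b.
        mask = (b_1 == a_1[i]) & (b_2 <= a_2[i]) & (b_3 <= a_3[i])
        hits = np.flatnonzero(mask)
        if hits.size:
            output[i] = b_4[hits[-1]]  # the last matching row wins, like the backward scan
    return output

# Tiny hand-checkable example.
a_1 = np.array([1, 1, 2, 7], dtype=np.int32)
a_2 = np.array([4, 2, 5, 5], dtype=np.int32)
a_3 = np.array([2, 2, 5, 5], dtype=np.int32)
b_1 = np.array([1, 1, 2, 1], dtype=np.int32)
b_2 = np.array([0, 5, 0, 3], dtype=np.int32)
b_3 = np.array([0, 0, 9, 1], dtype=np.int32)
b_4 = np.array([10, 20, 30, 40], dtype=np.int32)
print(join_reference(a_1, a_2, a_3, b_1, b_2, b_3, b_4).tolist())  # [40, 10, 0, 0]
```

Row 0 matches rows 0 and 3 of b and takes the later one; row 2 has a matching key but fails the b_3 condition; row 3 has no matching key at all.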
Using Numba here is a good idea since the operation is particularly expensive. That being said, the complexity of the algorithm is O(n·m), where n = len(a_1) and m = len(b_1), though it is not easy to do much better without making the code much more complex. Moreover, the array b_1, which might not fit in the L3 cache, can be read in full up to 5_000_000 times (once per row of a), making the code rather memory-bound.
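To put numbers on this for the benchmark sizes in the question, here is a plain back-of-the-envelope computation. The first figure is an upper bound, since the early `break` usually stops the backward scan before it reaches the start of b; for the real ~600M-row a, it would be 120 times larger.

```python
# Back-of-the-envelope numbers for the benchmark sizes in the question.
length_a = 5_000_000   # rows of a
length_b = 2_000_000   # rows of b
int32_bytes = 4

# Upper bound for the naive kernel: every row of a scans all of b
# (the early `break` usually stops the scan sooner).
worst_case_iterations = length_a * length_b

# Memory footprint of b_1 alone, which the naive kernel streams repeatedly.
b1_megabytes = length_b * int32_bytes / 1e6

print(f"worst-case inner-loop iterations: {worst_case_iterations:.1e}")  # 1.0e+13
print(f"size of b_1: {b1_megabytes:.0f} MB")  # 8 MB
```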
We can speed up the code considerably by building an index so as not to traverse the whole array b_1, but only the positions j where a_1[i] == b_1[j]. This is not enough to improve the complexity, since many j values fulfil this condition: with 999 distinct values in b_1, each key matches about 2_000 rows of b. We could improve the (average) complexity by building a kind of tree for each entry of the index, but in practice this makes the code much more complex, and the time to build the trees would be so large that it is not worth it. Indeed, a basic index is enough to strongly reduce the execution time on the provided random dataset (with uniformly distributed numbers). Here is the resulting code:
```python
import numba as nb
import numpy as np
import time

length_a = 5_000_000
length_b = 2_000_000
a_1 = np.random.randint(1, 1_000, length_a, dtype=np.int32)
a_2 = np.random.randint(1, 1_000, length_a, dtype=np.int32)
a_3 = np.random.randint(1, 1_000, length_a, dtype=np.int32)
b_1 = np.random.randint(1, 1_000, length_b, dtype=np.int32)
b_2 = np.random.randint(1, 1_000, length_b, dtype=np.int32)
b_3 = np.random.randint(1, 1_000, length_b, dtype=np.int32)
b_4 = np.random.randint(1, 1_000, length_b, dtype=np.int32)

IntList = nb.types.ListType(nb.types.int32)

@nb.njit(nb.int32[:](nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:], nb.int32[:]), parallel=True)
def join_multi_ineq_fast(a_1, a_2, a_3, b_1, b_2, b_3, b_4):
    output = np.zeros(len(a_1), dtype=np.int32)

    # Index: value of b_1 -> positions j in b (appended in increasing order).
    b1_indices = nb.typed.Dict.empty(key_type=nb.types.int32, value_type=IntList)
    for j in range(len(b_1)):
        val = b_1[j]
        if val in b1_indices:
            b1_indices[val].append(j)
        else:
            lst = nb.typed.List.empty_list(item_type=np.int32)
            lst.append(j)
            b1_indices[val] = lst

    for i in nb.prange(len(a_1)):
        if a_1[i] in b1_indices:
            indices = b1_indices[a_1[i]]
            v2 = a_2[i]
            v3 = a_3[i]
            # Scan backwards so the last matching row of b wins, as in the naive code.
            for k in range(len(indices) - 1, -1, -1):
                j = indices[np.uint32(k)]
                #assert a_1[i] == b_1[j]
                if v2 >= b_2[j] and v3 >= b_3[j]:
                    output[i] = b_4[j]
                    break
    return output

# IPython/Jupyter magic; in a plain script, time the call with time.time() instead.
%time join_multi_ineq_fast(a_1, a_2, a_3, b_1, b_2, b_3, b_4)
```
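The same idea can also be expressed without a typed dict of lists. The sketch below swaps in a different index layout, a sorted-offsets (CSR-like) grouping built with `np.argsort` and `np.unique`, so that the rows of each key form a contiguous slice of a single array. It is written in plain Python/NumPy so it can be checked without Numba; the names `build_index` and `join_with_index` are hypothetical. `build_index` uses `np.unique(..., return_index=True)` and is meant to run in regular NumPy, while the loop in `join_with_index` only uses array indexing and `np.searchsorted`. Compiling it with `@nb.njit` (possibly with `nb.prange` over i) should be possible, but whether that beats the dict version is an assumption, not a measured result.

```python
import numpy as np

def build_index(b_1):
    """Group the row positions of b by key in a sorted-offsets (CSR-like) layout.

    For group g, order[starts[g]:starts[g + 1]] holds, in increasing order, the
    positions j with b_1[j] == keys[g]. The stable sort keeps equal keys in
    their original order, so a backward scan of a group meets the last row first.
    """
    order = np.argsort(b_1, kind="stable")
    keys, starts = np.unique(b_1[order], return_index=True)
    starts = np.append(starts, len(b_1))  # sentinel: end of the last group
    return keys, starts, order


def join_with_index(a_1, a_2, a_3, b_2, b_3, b_4, keys, starts, order):
    output = np.zeros(len(a_1), dtype=np.int32)
    for i in range(len(a_1)):
        g = np.searchsorted(keys, a_1[i])  # binary search for the key
        if g == len(keys) or keys[g] != a_1[i]:
            continue  # no row of b has this key: keep the 0 sentinel
        # Walk the group from its last row to its first, stopping at the first match.
        for k in range(starts[g + 1] - 1, starts[g] - 1, -1):
            j = order[k]
            if a_2[i] >= b_2[j] and a_3[i] >= b_3[j]:
                output[i] = b_4[j]
                break
    return output


# Tiny hand-checkable example.
a_1 = np.array([1, 1, 2, 7], dtype=np.int32)
a_2 = np.array([4, 2, 5, 5], dtype=np.int32)
a_3 = np.array([2, 2, 5, 5], dtype=np.int32)
b_1 = np.array([1, 1, 2, 1], dtype=np.int32)
b_2 = np.array([0, 5, 0, 3], dtype=np.int32)
b_3 = np.array([0, 0, 9, 1], dtype=np.int32)
b_4 = np.array([10, 20, 30, 40], dtype=np.int32)
keys, starts, order = build_index(b_1)
print(join_with_index(a_1, a_2, a_3, b_2, b_3, b_4, keys, starts, order).tolist())  # [40, 10, 0, 0]
```

Because the backward scan over `order[starts[g]:starts[g + 1]]` visits the rows of a key in decreasing j, it returns the same row as the naive kernel.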
Note that, on average, only 32 values of k are tested per row of a (which is small enough not to justify a more efficient but more complicated data structure). Also note that the result is strictly identical to the one produced by the naive implementation.
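The number of tested k values depends on the data. One way to estimate it is to instrument the backward scan and count how many candidates are examined before the first match. The sketch below does this in plain Python (hypothetical function `average_probes`) on a smaller synthetic sample. The sample sizes and the choice of 100 keys are assumptions, picked to keep roughly the same ~2_000 rows per key as the benchmark while staying fast without Numba, so the value it prints is only an estimate and is not expected to reproduce the 32 reported above exactly.

```python
import numpy as np

def average_probes(a_1, a_2, a_3, b_1, b_2, b_3):
    """Average number of candidate rows tested per row of a when each key's
    rows are scanned backwards and the scan stops at the first match."""
    # Python lists are much faster than NumPy scalars in a pure-Python loop.
    a_1, a_2, a_3 = a_1.tolist(), a_2.tolist(), a_3.tolist()
    b_2, b_3 = b_2.tolist(), b_3.tolist()
    groups = {}
    for j, key in enumerate(b_1.tolist()):
        groups.setdefault(key, []).append(j)  # positions stored in increasing j
    total = 0
    for key, v2, v3 in zip(a_1, a_2, a_3):
        for j in reversed(groups.get(key, [])):
            total += 1  # one more candidate k tested
            if v2 >= b_2[j] and v3 >= b_3[j]:
                break
    return total / len(a_1)

# Hypothetical smaller sample: 100 keys keep about 2_000 rows of b per key.
rng = np.random.default_rng(0)
n_a, n_b = 20_000, 200_000
a_1 = rng.integers(1, 101, n_a, dtype=np.int32)
a_2 = rng.integers(1, 1_000, n_a, dtype=np.int32)
a_3 = rng.integers(1, 1_000, n_a, dtype=np.int32)
b_1 = rng.integers(1, 101, n_b, dtype=np.int32)
b_2 = rng.integers(1, 1_000, n_b, dtype=np.int32)
b_3 = rng.integers(1, 1_000, n_b, dtype=np.int32)
print(f"average candidates tested per row of a: {average_probes(a_1, a_2, a_3, b_1, b_2, b_3):.1f}")
```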
Here are the results on my i5-9600KF CPU (6 cores):
Roman's code: >120.00 sec (requires a HUGE amount of RAM: >16 GiB)
Naive Numba code: 24.85 sec
This implementation: 0.83 sec <-----
Thus, this implementation is about 30 times faster than the naive Numba code.