Tags: python, numpy, vectorization, set-operations

Numpy: find row-wise common element efficiently


Suppose we are given two 2D numpy arrays a and b with the same number of rows. Assume furthermore that we know that each row i of a and b has at most one element in common, though this element may occur multiple times. How can we find this element as efficiently as possible?

An example:

import numpy as np

a = np.array([[1, 2, 3],
              [2, 5, 2],
              [5, 4, 4],
              [2, 1, 3]])

b = np.array([[4, 5],
              [3, 2],
              [1, 5],
              [0, 5]])

desiredResult = np.array([[np.nan],
                          [2],
                          [5],
                          [np.nan]])

It is easy to come up with a straightforward implementation by applying intersect1d along the first axis:

from itertools import starmap

desiredResult = np.array(list(starmap(np.intersect1d, zip(a, b))))

Apparently, using Python's builtin set operations is even quicker. Converting the result to the desired form is easy.
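For concreteness, the set-based baseline alluded to here might look as follows (a sketch; the function name `common_per_row` is made up):

```python
import numpy as np

def common_per_row(a, b):
    # One Python-level pass per row; NaN where a row pair has no common value.
    out = np.full(len(a), np.nan)
    for i, (row_a, row_b) in enumerate(zip(a, b)):
        common = set(row_a) & set(row_b)
        if common:                      # at most one element by assumption
            out[i] = common.pop()
    return out

a = np.array([[1, 2, 3], [2, 5, 2], [5, 4, 4], [2, 1, 3]])
b = np.array([[4, 5], [3, 2], [1, 5], [0, 5]])
res = common_per_row(a, b)              # -> [nan,  2.,  5., nan]
```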

However, I need an implementation that is as efficient as possible. Hence, I do not like the starmap, since I suppose it requires a Python call for every row. I would like a purely vectorized solution, and would be happy if it even exploited our additional knowledge that there is at most one common value per row.

Does anyone have ideas for how I could speed up the task and implement the solution more elegantly? I would be okay with using C code or Cython, but the coding effort should not be too great.


Solution

  • Approach #1

    Here's a vectorized one based on searchsorted2d -
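    Note that searchsorted2d is not a NumPy builtin; it refers to a helper that emulates a row-wise searchsorted. One common way to sketch it is to shift each row by a distinct multiple of a number exceeding the combined value range, so a single flat np.searchsorted call serves all rows at once (this particular offset scheme is an assumption):

```python
import numpy as np

def searchsorted2d(a, b):
    # Shift each row into its own disjoint value range, run one flat
    # searchsorted over the raveled arrays, then undo the flat offsets.
    # Assumes each row of a is already sorted.
    m, n = a.shape
    max_num = max(a.max(), b.max()) - min(a.min(), b.min()) + 1
    r = max_num * np.arange(m)[:, None]
    p = np.searchsorted((a + r).ravel(), (b + r).ravel()).reshape(b.shape)
    return p - n * np.arange(m)[:, None]

a = np.array([[1, 2, 3], [2, 2, 5], [4, 4, 5], [1, 2, 3]])  # rows pre-sorted
b = np.array([[4, 5], [2, 3], [1, 5], [0, 5]])
idx = searchsorted2d(a, b)   # -> [[3, 3], [0, 2], [0, 2], [0, 3]]
```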

    # Sort each row of a and b in-place
    a.sort(1)
    b.sort(1)
    
    # Use 2D searchsorted row-wise between a and b
    idx = searchsorted2d(a,b)
    
    # "Clip-out" out of bounds indices
    idx[idx==a.shape[1]] = 0
    
    # Get mask of valid ones i.e. matches
    mask = np.take_along_axis(a,idx,axis=1)==b
    
    # Use argmax to get first match as we know there's at most one match
    match_val = np.take_along_axis(b,mask.argmax(1)[:,None],axis=1)
    
    # Finally use np.where to choose between valid match 
    # (decided by any one True in each row of mask)
    out = np.where(mask.any(1)[:,None],match_val,np.nan)
    

    Approach #2

    Numba-based one for memory efficiency -

    from numba import njit
    
    @njit(parallel=True)
    def numba_f1(a,b,out):
        n,a_ncols = a.shape
        b_ncols = b.shape[1]
        for i in range(n):
            for j in range(a_ncols):
                for k in range(b_ncols):
                    m = a[i,j]==b[i,k]
                    if m:
                        break
                if m:
                    out[i] = a[i,j]
                    break
        return out
    
    def find_first_common_elem_per_row(a,b):
        out = np.full(len(a),np.nan)
        numba_f1(a,b,out)
        return out
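    For anyone without numba installed, the kernel's early-exit logic can be mirrored in plain Python (a reference version for checking correctness, not for speed; the function name is made up):

```python
import numpy as np

def find_first_common_elem_per_row_py(a, b):
    # Same logic as the numba kernel: scan a[i] left to right and stop
    # at the first value that also occurs in b[i].
    out = np.full(len(a), np.nan)
    for i in range(len(a)):
        for x in a[i]:
            if x in b[i]:          # linear scan; b rows are short
                out[i] = x
                break
    return out

a = np.array([[1, 2, 3], [2, 5, 2], [5, 4, 4], [2, 1, 3]])
b = np.array([[4, 5], [3, 2], [1, 5], [0, 5]])
res = find_first_common_elem_per_row_py(a, b)   # -> [nan,  2.,  5., nan]
```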
    

    Approach #3

    Here's another vectorized one based on stacking and sorting -

    r = np.arange(len(a))
    ab = np.hstack((a,b))
    idx = ab.argsort(1)
    ab_s = ab[r[:,None],idx]
    m = ab_s[:,:-1] == ab_s[:,1:]
    m2 = (idx[:,1:]*m)>=a.shape[1]
    m3 = m & m2
    out = np.where(m3.any(1),b[r,idx[r,m3.argmax(1)+1]-a.shape[1]],np.nan)
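    Traced on the question's (unsorted) example arrays, with comments on what each intermediate holds:

```python
import numpy as np

a = np.array([[1, 2, 3], [2, 5, 2], [5, 4, 4], [2, 1, 3]])
b = np.array([[4, 5], [3, 2], [1, 5], [0, 5]])

r = np.arange(len(a))
ab = np.hstack((a, b))               # stack each row of a and b side by side
idx = ab.argsort(1)                  # per-row sort order of the stacked row
ab_s = ab[r[:, None], idx]           # stacked rows, sorted
m = ab_s[:, :-1] == ab_s[:, 1:]      # equal neighbours = duplicate values
m2 = (idx[:, 1:] * m) >= a.shape[1]  # right member of the pair came from b
m3 = m & m2                          # duplicate pairs that straddle a and b
out = np.where(m3.any(1), b[r, idx[r, m3.argmax(1) + 1] - a.shape[1]], np.nan)
# out -> array([nan,  2.,  5., nan])
```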
    

    Approach #4

    For an elegant one, we can make use of broadcasting for a resource-hungry method -

    m = (a[:,None]==b[:,:,None]).any(2)
    out = np.where(m.any(1),b[np.arange(len(a)),m.argmax(1)],np.nan)
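
    On the example arrays this reads as follows; the intermediate boolean array has shape (rows, b_cols, a_cols), which is what makes the approach memory-hungry for large inputs:

```python
import numpy as np

a = np.array([[1, 2, 3], [2, 5, 2], [5, 4, 4], [2, 1, 3]])
b = np.array([[4, 5], [3, 2], [1, 5], [0, 5]])

# a[:, None] has shape (rows, 1, a_cols) and b[:, :, None] has shape
# (rows, b_cols, 1), so the comparison broadcasts to (rows, b_cols, a_cols).
m = (a[:, None] == b[:, :, None]).any(2)   # m[i, k]: does b[i, k] occur in a[i]?
out = np.where(m.any(1), b[np.arange(len(a)), m.argmax(1)], np.nan)
# out -> array([nan,  2.,  5., nan])
```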