Tags: python, numpy, bit-manipulation, numpy-einsum

Determining the validity of a multi-hot encoding


Suppose I have N items and a multi-hot vector of values {0, 1} that represents inclusion of these items in a result:

N = 4

# items 1 and 3 will be included in the result
vector = [0, 1, 0, 1]

# item 2 will be included in the result
vector = [0, 0, 1, 0]

I'm also provided a matrix of conflicts which indicates which items cannot be included in the result at the same time:

conflicts = [
  [0, 1, 1, 0], # any result that contains items 1 AND 2 is invalid
  [0, 1, 1, 1], # any result that contains AT LEAST 2 items from {1, 2, 3} is invalid
]

Given this matrix of conflicts, we can determine the validity of the earlier vectors:

# invalid as it triggers conflict 1: [0, 1, 1, 1]
vector = [0, 1, 0, 1]

# valid as it triggers no conflicts
vector = [0, 0, 1, 0]

A naive solution to detect whether a given vector is "valid" (i.e. does not trigger any conflicts) can be implemented via a dot product followed by a max reduction in numpy:

violation = np.dot(conflicts, vector)
is_valid = np.max(violation) <= 1
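
For concreteness, here is a self-contained version of this naive check applied to the example vectors above (the check_naive helper only wraps the two lines shown):

import numpy as np

conflicts = np.array([[0, 1, 1, 0],
                      [0, 1, 1, 1]])

def check_naive(vector, conflicts):
    # a conflict row is triggered when at least 2 of its items are included
    violation = np.dot(conflicts, vector)
    return np.max(violation) <= 1

print(check_naive(np.array([0, 1, 0, 1]), conflicts))  # False, triggers [0, 1, 1, 1]
print(check_naive(np.array([0, 0, 1, 0]), conflicts))  # True, no conflict triggered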

Are there more efficient ways to perform this operation, perhaps via np.einsum or by bypassing numpy arrays entirely in favour of bit manipulation?

We assume that the number of vectors being checked can be very large (e.g. up to 2^N if we evaluate all possibilities) but that only one vector is likely being checked at a time (to avoid generating a matrix of shape up to (2^N, N) as input).


Solution

  • TL;DR: you can use Numba to optimize np.dot so that it operates only on binary values. More specifically, you can perform SIMD-like operations on 8 bytes at once using 64-bit views.




    Converting lists to arrays

    First of all, the lists can be efficiently converted to relatively-compact arrays using this approach:

    vector = np.fromiter(vector, np.uint8)
    conflicts = np.array([np.fromiter(conflicts[i], np.uint8) for i in range(len(conflicts))])
    

    This is faster than relying on the automatic NumPy conversion with np.array (there are fewer checks to perform internally, NumPy knows in advance what type of array to build, and the resulting array is smaller in memory and thus faster to fill). This step alone can be used to speed up your np.dot-based solution.

    If the inputs are already NumPy arrays, check that they are of type np.uint8 or np.int8. Otherwise, cast them to such a type, for example with conflicts = conflicts.astype(np.uint8), as in the small check below.
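
    As a minimal sketch (assuming the same variable names as above), the dtype check and cast could look like this:

    # Ensure compact 1-byte arrays before calling np.dot or the Numba code below
    if conflicts.dtype != np.uint8:
        conflicts = conflicts.astype(np.uint8)
    if vector.dtype != np.uint8:
        vector = vector.astype(np.uint8)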


    First try

    One solution could be to use np.packbits to pack the input binary values as tightly as possible into an array of bits in memory, and then perform logical ANDs, as sketched below. But it turns out that np.packbits is pretty slow, so this solution is not a good idea in the end. In fact, any solution that creates temporary arrays with a shape similar to conflicts will be slow, since writing such an array to memory is generally slower than np.dot (which reads conflicts from memory only once).
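
    For completeness, a rough sketch of that packbits idea (not recommended, for the reason above) could look like this:

    # Pack 8 binary values per byte, AND the packed rows with the packed vector,
    # then count the matching bits per conflict row by unpacking again.
    packed_vector = np.packbits(vector)                  # shape: (ceil(N / 8),)
    packed_conflicts = np.packbits(conflicts, axis=1)    # shape: (n_conflicts, ceil(N / 8))
    matches = np.bitwise_and(packed_conflicts, packed_vector)
    counts = np.unpackbits(matches, axis=1).sum(axis=1)  # matched items per conflict row
    is_valid = np.max(counts) <= 1

    Note that np.unpackbits recreates a conflicts-sized temporary array, which is exactly the kind of cost described above.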


    Using Numba

    Since np.dot is pretty well optimized, the only way to beat it is to use optimized native code. Numba can generate native executable code at runtime from NumPy-based Python code thanks to its just-in-time compiler. The idea is to perform logical ANDs between vector and each row of conflicts, block by block. Conflicts are checked after each block so that the computation can stop as early as possible. Blocks can be compared efficiently in groups of 8 bytes by comparing the uint64 views of the two arrays (in a SIMD-friendly way).

    import numba as nb
    import numpy as np
    
    @nb.njit('bool_(uint8[::1], uint8[:,::1])')
    def check_valid(vector, conflicts):
        n, m = conflicts.shape
        assert vector.size == m
    
        for i in range(n):
            block_size = 128 # Multiple of 8, at most 248, so the byte-wise sums below cannot overflow
            conflicts_row = conflicts[i,:]
            gsum = 0 # Global sum of conflicts
            m_limit = m // block_size * block_size
    
            for j in range(0, m_limit, block_size):
                vector_block = vector[j:j+block_size].view(np.uint64)
                conflicts_block = conflicts_row[j:j+block_size].view(np.uint64)
    
                # Matching
                lsum = np.uint64(0) # 8 local sums of conflicts
                for k in range(block_size//8):
                    lsum += vector_block[k] & conflicts_block[k]
    
                # Trick to sum the 8 byte-wise counters of lsum into its lowest byte
                lsum += lsum >> 32
                lsum += lsum >> 16
                lsum += lsum >> 8
                gsum += lsum & 0xFF
    
                # Check if there is a conflict
                if gsum >= 2:
                    return False
    
            # Remaining part
            for j in range(m_limit, m):
                gsum += vector[j] & conflicts_row[j]
    
            if gsum >= 2:
                return False
    
        return True
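
    For example, using it with the converted example arrays from the question:

    vector = np.fromiter([0, 1, 0, 1], np.uint8)
    conflicts = np.array([[0, 1, 1, 0], [0, 1, 1, 1]], dtype=np.uint8)
    print(check_valid(vector, conflicts))  # False: conflict [0, 1, 1, 1] is triggered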
    

    Results

    This is about 9 times faster than np.dot on my machine for a large conflicts array of shape (16, 65536) that triggers no conflicts (so there is no early exit). The time to convert the lists is not included in either case. When there are conflicts, the provided solution is much faster still, since it can stop the computation early.

    Theoretically, the computation should be even faster, but the Numba JIT does not manage to vectorize the loop using SIMD instructions. That being said, the same issue seems to appear for np.dot. If the arrays are even bigger, you can parallelize the computation of the blocks (at the expense of a slower computation when the function returns False), as sketched below.
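
    A simplified sketch of such a parallel variant (assumptions: the check_valid_parallel name is made up here, it parallelizes over conflict rows with nb.prange rather than over blocks, and it drops the early exit and the uint64 block trick for brevity):

    import numba as nb
    import numpy as np

    @nb.njit('bool_(uint8[::1], uint8[:,::1])', parallel=True)
    def check_valid_parallel(vector, conflicts):
        n, m = conflicts.shape
        violations = 0
        for i in nb.prange(n):
            s = 0  # matched items for conflict row i
            for j in range(m):
                s += vector[j] & conflicts[i, j]
            if s >= 2:
                violations += 1  # scalar reduction supported inside prange
        return violations == 0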