Suppose I have N items and a multi-hot vector of values {0, 1} that represents inclusion of these items in a result:
N = 4
# items 1 and 3 will be included in the result
vector = [0, 1, 0, 1]
# item 2 will be included in the result
vector = [0, 0, 1, 0]
I'm also provided a matrix of conflicts which indicates which items cannot be included in the result at the same time:
conflicts = [
[0, 1, 1, 0], # any result that contains items 1 AND 2 is invalid
[0, 1, 1, 1], # any result that contains AT LEAST 2 items from {1, 2, 3} is invalid
]
Given this matrix of conflicts, we can determine the validity of the earlier vectors:
# invalid as it triggers conflict 1: [0, 1, 1, 1]
vector = [0, 1, 0, 1]
# valid as it triggers no conflicts
vector = [0, 0, 1, 0]
A naive way to detect whether a given vector is "valid" (i.e. does not trigger any conflicts) is a dot product followed by a maximum check in numpy:
violation = np.dot(conflicts, vector)
is_valid = np.max(violation) <= 1
Are there more efficient ways to perform this operation, perhaps via np.einsum or by bypassing numpy arrays entirely in favour of bit manipulation?
We assume that the number of vectors being checked can be very large (e.g. up to 2^N if we evaluate all possibilities) but that only one vector is likely to be checked at a time (to avoid generating a matrix of shape up to (2^N, N) as input).
TL;DR: you can use Numba to write a specialized replacement for np.dot that operates only on binary values. More specifically, you can perform SIMD-like operations on 8 bytes at once using 64-bit views.
First of all, the lists can be efficiently converted to relatively-compact arrays using this approach:
vector = np.fromiter(vector, np.uint8)
conflicts = np.array([np.fromiter(conflicts[i], np.uint8) for i in range(len(conflicts))])
This is faster than relying on the automatic Numpy conversion or np.array: Numpy has fewer checks to perform internally, it knows what type of array to build, and the resulting array is smaller in memory and thus faster to fill. This step can also be used to speed up your np.dot-based solution.
If the inputs are already Numpy arrays, check that they are of type np.uint8 or np.int8. Otherwise, cast them to such a type, for example with conflicts = conflicts.astype(np.uint8).
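As a minimal sketch of the conversion step feeding the np.dot baseline, here is the example data from the question run end to end. The names conflicts_list and vector_list are only illustrative; they stand for the original Python lists.

```python
import numpy as np

# Example data from the question, as plain Python lists (illustrative names).
conflicts_list = [
    [0, 1, 1, 0],
    [0, 1, 1, 1],
]
vector_list = [0, 1, 0, 1]

# Build compact uint8 arrays directly from the lists.
vector = np.fromiter(vector_list, np.uint8)
conflicts = np.array([np.fromiter(row, np.uint8) for row in conflicts_list])

# Baseline check: count how many selected items each conflict row matches.
violation = np.dot(conflicts, vector)
is_valid = bool(np.max(violation) <= 1)
print(violation.tolist(), is_valid)  # [1, 2] False
```

Note that with uint8 inputs np.dot also returns uint8 values, so a count above 255 would wrap around. If a row can match that many items, casting vector to a wider integer type before the dot product avoids the problem.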
Then, one solution could be to use np.packbits to pack the input binary values as tightly as possible into an array of bits in memory, and then perform logical ANDs. But it turns out that np.packbits is pretty slow, so this solution is not a good idea in the end. In fact, any solution creating temporary arrays with a shape similar to conflicts will be slow, since writing such an array to memory is generally slower than np.dot (which reads conflicts from memory once).
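For reference, the "pack then AND" idea can also be sketched without Numpy at all, using plain Python integers as bit sets. This is only an illustration of the technique, not something benchmarked here: it assumes the conflict rows are packed once up front and that vectors can be kept in packed form. The helpers pack_bits and is_valid_packed are hypothetical names.

```python
def pack_bits(bits):
    """Pack a list of 0/1 values into one Python int (item 0 is the lowest bit)."""
    mask = 0
    for position, bit in enumerate(bits):
        if bit:
            mask |= 1 << position
    return mask


def is_valid_packed(vector_mask, conflict_masks):
    """Return False as soon as a conflict row shares two or more items with the vector."""
    for conflict_mask in conflict_masks:
        shared = vector_mask & conflict_mask
        # shared & (shared - 1) clears the lowest set bit,
        # so a non-zero result means at least two bits are set.
        if shared & (shared - 1):
            return False
    return True


# Conflict rows from the question, packed once.
conflict_masks = [pack_bits(row) for row in [[0, 1, 1, 0], [0, 1, 1, 1]]]

print(is_valid_packed(pack_bits([0, 1, 0, 1]), conflict_masks))  # False
print(is_valid_packed(pack_bits([0, 0, 1, 0]), conflict_masks))  # True
```

If all 2^N candidates are enumerated, the loop counter itself can serve as vector_mask, which avoids packing each vector separately.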
Since np.dot is pretty well optimized, the only way to beat it is to use optimized native code. Numba can generate native code at runtime from Numpy-based Python code thanks to a just-in-time compiler. The idea is to perform logical ANDs between vector and the rows of conflicts block by block. Conflicts are checked after each block so as to stop the computation as early as possible. Blocks can be processed efficiently in groups of 8 bytes by operating on the uint64 views of the two arrays (in a SIMD-friendly way).
import numba as nb
import numpy as np

@nb.njit('bool_(uint8[::1], uint8[:,::1])')
def check_valid(vector, conflicts):
    n, m = conflicts.shape
    assert vector.size == m
    for i in range(n):
        block_size = 128 # In the range: 8,16,...,248
        conflicts_row = conflicts[i,:]
        gsum = 0 # Running sum of matches for this conflict row
        m_limit = m // block_size * block_size
        for j in range(0, m_limit, block_size):
            # Reinterpret each block of bytes as 64-bit words (8 items per word)
            vector_block = vector[j:j+block_size].view(np.uint64)
            conflicts_block = conflicts_row[j:j+block_size].view(np.uint64)
            # Matching
            lsum = np.uint64(0) # 8 local sums of conflicts, one per byte
            for k in range(block_size//8):
                lsum += vector_block[k] & conflicts_block[k]
            # Trick to perform the reduction of all the bytes in lsum
            lsum += lsum >> 32
            lsum += lsum >> 16
            lsum += lsum >> 8
            gsum += lsum & 0xFF
            # Check if there is a conflict
            if gsum >= 2:
                return False
        # Remaining part (items that do not fill a whole block)
        for j in range(m_limit, m):
            gsum += vector[j] & conflicts_row[j]
        if gsum >= 2:
            return False
    return True
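The byte reduction trick used in the function can be checked in isolation. The following is a small pure-Python sketch (sum_bytes and MASK64 are illustrative names, and Python integers are masked to emulate a 64-bit word): it folds the 8 byte-sized counters of a word into the lowest byte, which works as long as the total stays below 256. That bound is why the block size is limited to 248.

```python
MASK64 = (1 << 64) - 1  # emulate uint64 wrap-around with Python ints


def sum_bytes(word):
    """Sum the 8 bytes of a 64-bit word; valid while the total stays below 256."""
    word = (word + (word >> 32)) & MASK64  # fold the upper 4 bytes onto the lower 4
    word = (word + (word >> 16)) & MASK64  # fold 4 partial sums into 2
    word = (word + (word >> 8)) & MASK64   # fold 2 partial sums into 1
    return word & 0xFF                     # the lowest byte holds the total


# Eight byte-sized counters packed into one little-endian word.
word = int.from_bytes(bytes([1, 0, 3, 0, 2, 0, 0, 5]), "little")
print(sum_bytes(word))  # 11
```

The Numba function above applies this same reduction once per block, after accumulating the ANDed words.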
The Numba function is about 9 times faster than np.dot on my machine for a large conflicts array of shape (16, 65536) (without conflicts). The time to convert the lists is not included in either case. When there are conflicts, the provided solution is much faster since it can stop the computation early.
Theoretically, the computation should be even faster, but the Numba JIT does not succeed in vectorizing the loop using SIMD instructions. That being said, the same issue seems to appear for np.dot. If the arrays are even bigger, you can parallelize the computation of the blocks (at the expense of a slower computation when the function returns False).