I am trying to make a fast and space-efficient set implementation for 64-bit unsigned ints. I don't want to use set(), as that converts everything into Python ints, which use much more space than 8 bytes per int. Here is my effort:
import numpy as np
from numba import njit
class HashSet:
    def __init__(self, capacity=1024):
        self.capacity = capacity
        self.size = 0
        self.EMPTY = np.uint64(0xFFFFFFFFFFFFFFFF)  # 2^64 - 1
        self.DELETED = np.uint64(0xFFFFFFFFFFFFFFFE)  # 2^64 - 2
        self.table = np.full(capacity, self.EMPTY)  # Initialize with a special value indicating empty

    def insert(self, key):
        if self.size >= self.capacity:
            raise RuntimeError("Hash table is full")
        if not self._insert(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash):
            print(f"Key already exists: {key}")
        else:
            self.size += 1

    def contains(self, key):
        return self._contains(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash)

    def remove(self, key):
        if self._remove(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash):
            self.size -= 1

    def __len__(self):
        return self.size

    @staticmethod
    @njit
    def _hash(key, capacity):
        return key % capacity

    @staticmethod
    @njit
    def _insert(table, key, capacity, EMPTY, DELETED, hash_func):
        index = hash_func(key, capacity)
        while table[index] != EMPTY and table[index] != DELETED and table[index] != key:
            index = (index + 1) % capacity
        if table[index] == key:
            return False  # Key already exists
        table[index] = key
        return True

    @staticmethod
    @njit
    def _contains(table, key, capacity, EMPTY, DELETED, hash_func):
        index = hash_func(key, capacity)
        while table[index] != EMPTY:
            if table[index] == key:
                return True
            index = (index + 1) % capacity
        return False

    @staticmethod
    @njit
    def _remove(table, key, capacity, EMPTY, DELETED, hash_func):
        index = hash_func(key, capacity)
        while table[index] != EMPTY:
            if table[index] == key:
                table[index] = DELETED
                return True
            index = (index + 1) % capacity
        return False
I am using numba wherever I can to speed things up, but it still isn't very fast. For example:
hash_set = HashSet(capacity=204800)
keys = np.random.randint(0, 2**64, size=100000, dtype=np.uint64)
def insert_and_remove(hash_set, key):
    hash_set.insert(np.uint64(key))
    hash_set.remove(key)
%timeit insert_and_remove(hash_set, keys[0])
This gives:
16.9 μs ± 407 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
I think the main cause is the code that I have failed to wrap with numba.
How can this be sped up?
EDIT
@ken suggested defining _hash as a global function outside the class. This speeds things up, so that it is now only ~50% slower than set().
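Concretely (this is my reading of the suggestion), the jitted helpers move to module level and call the global _hash directly instead of receiving it as an argument, which lets numba resolve the call at compile time. Only _contains is sketched below; _insert and _remove follow the same pattern.

import numpy as np
from numba import njit

@njit
def _hash(key, capacity):
    return key % capacity

@njit
def _contains(table, key, capacity, EMPTY, DELETED):
    # _hash is resolved as a compile-time global, not a runtime argument
    index = _hash(key, capacity)
    while table[index] != EMPTY:
        if table[index] == key:
            return True
        index = (index + 1) % capacity
    return False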
As requested, here is the class, but using jitclass. I'm not sure how much value all the type annotations add; I had been playing around to see if I could get any improvements. Overall, your original code had a peak performance of 20 μs, whereas the code below had a peak performance of 2.3 μs (an order of magnitude faster). However, using a Python set was an order of magnitude faster again, at 0.34 μs. These timings are only with the test harness you provided; no other performance testing was done.

The main things I had to do to get your code working with jitclass are:

1. Casting the integer literal 1 in index = (index + 1) % capacity to a numba.uint64. Without this, numba would promote the type of the expression to numba.float64, even when the local had a type annotation of uint64. Trying to index an array with a float caused the whole compilation step to fail.
2. Removing the njit decorators from the methods on the class. jitclass automatically applies an njit to all methods, and errors out if any of the class's methods have already been jit-ed.

The only bits you really need are:
# Extra arg to decorator tells numba what the dtype of the array is expected to be
@jitclass([('table', numba.uint64[:])])
class HashSet:
    capacity: numba.uint64
    size: numba.uint64
    table: np.ndarray
and
index = (index + numba.uint64(1)) % self.capacity
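To see why the cast is needed, here is a quick illustration of the promotion that broke compilation (numba follows NumPy's rules here, under which mixing uint64 with a signed integer unifies to float64):

from numba import njit
import numpy as np

@njit
def bump(i):
    return i + 1  # uint64 + signed int literal: numba unifies the types to float64

print(bump(np.uint64(3)))  # prints 4.0, not 4 -- a float, unusable as an array index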
In addition, I also made EMPTY and DELETED global constants. This gets you a small space saving if you have lots of small sets, without any loss in performance. With numba they truly are constants, not just global variables.
import numpy as np
import numba
from numba.experimental import jitclass
EMPTY = numba.uint64(0xFFFFFFFFFFFFFFFF) # 2^64 - 1
DELETED = numba.uint64(0xFFFFFFFFFFFFFFFE) # 2^64 - 2
@jitclass([('table', numba.uint64[:])])
class HashSet:
    capacity: numba.uint64
    size: numba.uint64
    table: np.ndarray

    def __init__(self, capacity: int = 1024) -> None:
        self.capacity = capacity
        self.size = 0
        self.table = np.full(self.capacity, EMPTY)  # Initialize with a special value indicating empty

    def __len__(self) -> int:
        return self.size

    @staticmethod
    def _hash(key: numba.uint64, capacity: numba.uint64) -> numba.uint64:
        return key % capacity

    def insert(self, key: numba.uint64) -> bool:
        if self.size >= self.capacity:
            raise RuntimeError("Hash table is full")
        index = self._hash(key, self.capacity)
        while self.table[index] != EMPTY and self.table[index] != DELETED and self.table[index] != key:
            index = (index + numba.uint64(1)) % self.capacity
        if self.table[index] == key:
            return False  # Key already exists
        self.table[index] = key
        self.size += 1
        return True

    def contains(self, key: numba.uint64) -> bool:
        index = self._hash(key, self.capacity)
        while self.table[index] != EMPTY:
            if self.table[index] == key:
                return True
            index = (index + numba.uint64(1)) % self.capacity
        return False

    def remove(self, key: numba.uint64) -> bool:
        index = self._hash(key, self.capacity)
        while self.table[index] != EMPTY:
            if self.table[index] == key:
                self.table[index] = DELETED
                self.size -= 1
                return True
            index = (index + numba.uint64(1)) % self.capacity
        return False
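For completeness, this drops straight into the harness from your question. Note that the first call triggers JIT compilation of the whole class, so warm it up before timing:

hash_set = HashSet(capacity=204800)
keys = np.random.randint(0, 2**64, size=100000, dtype=np.uint64)

def insert_and_remove(hash_set, key):
    hash_set.insert(np.uint64(key))
    hash_set.remove(key)

insert_and_remove(hash_set, keys[0])  # warm-up call: compiles the jitclass methods
%timeit insert_and_remove(hash_set, keys[0])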