python performance numba

Is it possible to speed up my set implementation?


I am trying to make a fast and space efficient set implementation for 64 bit unsigned ints. I don't want to use set() as that converts everything into Python ints that use much more space than 8 bytes per int. Here is my effort:

import numpy as np
from numba import njit

class HashSet:
    """Fixed-capacity set of unsigned 64-bit keys stored in a NumPy array.

    Collisions are resolved by linear probing. Two values are reserved as
    slot markers and can never be stored as keys: ``2**64 - 1`` marks a slot
    that has never been used (EMPTY) and ``2**64 - 2`` marks a slot whose key
    was removed (DELETED).
    """

    def __init__(self, capacity=1024):
        """Create an empty set with room for ``capacity`` keys.

        Raises:
            ValueError: if ``capacity`` is not positive (the hash is
                ``key % capacity``, so zero slots cannot work).
        """
        if capacity < 1:
            raise ValueError("capacity must be a positive integer")
        self.capacity = capacity
        self.size = 0
        self.EMPTY = np.uint64(0xFFFFFFFFFFFFFFFF)  # 2^64 - 1
        self.DELETED = np.uint64(0xFFFFFFFFFFFFFFFE)  # 2^64 - 2
        self.table = np.full(capacity, self.EMPTY)  # Initialize with a special value indicating empty

    def insert(self, key):
        """Add ``key`` to the set.

        Returns:
            True if the key was added, False if it was already present (in
            which case a message is printed, as before).

        Raises:
            RuntimeError: if every slot already holds a key.
            ValueError: if ``key`` equals one of the reserved marker values.
        """
        if self.size >= self.capacity:
            raise RuntimeError("Hash table is full")
        inserted = self._insert(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash)
        if inserted:
            self.size += 1
        else:
            print(f"Key already exists: {key}")
        return inserted

    def contains(self, key):
        """Return True if ``key`` is currently stored in the set."""
        return self._contains(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash)

    def remove(self, key):
        """Remove ``key`` if present; return True if something was removed."""
        removed = self._remove(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash)
        if removed:
            self.size -= 1
        return removed

    def __len__(self):
        """Return the number of keys currently stored."""
        return self.size

    @staticmethod
    @njit
    def _hash(key, capacity):
        """Map ``key`` to its home slot."""
        return key % capacity

    @staticmethod
    @njit
    def _insert(table, key, capacity, EMPTY, DELETED, hash_func):
        """Store ``key`` in ``table``; return False if it is already there.

        The whole probe chain (up to the first EMPTY slot) is searched for
        ``key`` before anything is written. Stopping at the first DELETED
        slot, as a naive implementation would, stores a second copy of a key
        that lives further along the chain.
        """
        if key == EMPTY or key == DELETED:
            raise ValueError("Key is reserved as a slot marker")

        index = hash_func(key, capacity)
        free_slot = index  # only meaningful once found_free is True
        found_free = False
        # At most `capacity` probes: a table without EMPTY slots must not
        # turn the search into an endless loop.
        for _ in range(capacity):
            slot = table[index]
            if slot == key:
                return False  # Key already exists
            if slot == EMPTY:
                # End of the chain: the key is absent. Prefer an earlier
                # tombstone so chains stay short.
                if not found_free:
                    free_slot = index
                    found_free = True
                break
            if slot == DELETED and not found_free:
                free_slot = index
                found_free = True
            index = (index + 1) % capacity

        if not found_free:
            raise RuntimeError("Hash table is full")

        table[free_slot] = key
        return True

    @staticmethod
    @njit
    def _contains(table, key, capacity, EMPTY, DELETED, hash_func):
        """Return True if ``key`` is found on its probe chain in ``table``."""
        # The marker values are never stored as keys; without this guard a
        # lookup of DELETED would match any tombstone.
        if key == EMPTY or key == DELETED:
            return False

        index = hash_func(key, capacity)
        for _ in range(capacity):  # bounded: the table may have no EMPTY slot left
            slot = table[index]
            if slot == key:
                return True
            if slot == EMPTY:
                return False
            index = (index + 1) % capacity
        return False

    @staticmethod
    @njit
    def _remove(table, key, capacity, EMPTY, DELETED, hash_func):
        """Replace ``key`` with a tombstone; return True if it was present."""
        # Removing a marker value must not "succeed" on a tombstone, which
        # would make the caller decrement its size counter wrongly.
        if key == EMPTY or key == DELETED:
            return False

        index = hash_func(key, capacity)
        for _ in range(capacity):  # bounded: the table may have no EMPTY slot left
            slot = table[index]
            if slot == key:
                table[index] = DELETED
                return True
            if slot == EMPTY:
                return False
            index = (index + 1) % capacity
        return False

I am using numba wherever I can to speed things up, but it still isn't very fast. For example:

# Benchmark setup: a table with 204800 slots and 100000 random uint64 keys.
# NOTE(review): the upper bound 2**64 is exclusive, so a drawn key can in
# principle equal the table's reserved marker values (2**64 - 1, 2**64 - 2).
hash_set = HashSet(capacity=204800)
keys = np.random.randint(0, 2**64, size=100000, dtype=np.uint64)
# Insert one key and remove it again, so the set is left without that key
# after every call.
def insert_and_remove(hash_set, key):
    hash_set.insert(np.uint64(key))  # key is converted to np.uint64 for the insert ...
    hash_set.remove(key)  # ... but passed through unchanged for the removal
# IPython magic: time a single insert/remove round trip using the first key.
%timeit insert_and_remove(hash_set, keys[0])

This gives:

16.9 μs ± 407 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

I think the main cause is the code that I have not managed to wrap with numba.

How can this be sped up?

EDIT

@ken suggested defining _hash as a global function outside the class. This speeds things up so now it is only ~50% slower than set().


Solution

  • As requested, here is the class, but using jitclass. I'm not sure how much value all the type annotations add; I had been playing around to see if I could get any improvements. Overall, your original code had a peak performance of 20 μs, whereas the code below had a peak performance of 2.3 μs (an order of magnitude faster). However, using a Python set was an order of magnitude faster again at 0.34 μs. These timings are only with the test harness you provided. No other performance testing was done.

    The main things I had to do to get your code working with jitclass are:

    The only bits you really need are:

    # The spec list passed to the decorator tells numba that `table` is a
    # one-dimensional uint64 array; the plain `np.ndarray` annotation below
    # does not say anything about the dtype.
    @jitclass([('table', numba.uint64[:])])
    class HashSet:
        capacity: numba.uint64
        size: numba.uint64
        table: np.ndarray
    

    and

    # The step is written as an explicit uint64 so the whole expression uses
    # unsigned 64-bit operands.
    # NOTE(review): presumably a bare `1` would mix signed and unsigned
    # integer types here — confirm against numba's integer typing rules.
    index = (index + numba.uint64(1)) % self.capacity
    

    In addition, I also made EMPTY and DELETED global constants. This gets you a small space saving if you have lots of small sets, without any loss in performance. With numba they truly are constants, and not just global variables.

    Code

    import numpy as np
    import numba
    from numba.experimental import jitclass
    
    # Slot markers shared by every HashSet: the two largest uint64 values are
    # set aside for "never used" and "key removed" respectively.
    EMPTY = numba.uint64(2**64 - 1)  # 0xFFFFFFFFFFFFFFFF
    DELETED = numba.uint64(2**64 - 2)  # 0xFFFFFFFFFFFFFFFE
    
    
    @jitclass([('table', numba.uint64[:])])
    class HashSet:
        """Fixed-capacity set of unsigned 64-bit keys, compiled with jitclass.

        Collisions are resolved by linear probing. The module-level values
        EMPTY and DELETED mark unused slots and removed keys, so they can
        never be stored as keys themselves.
        """

        capacity: numba.uint64
        size: numba.uint64
        table: np.ndarray

        def __init__(self, capacity: int = 1024) -> None:
            """Create an empty set with room for ``capacity`` keys.

            Raises ValueError if ``capacity`` is not positive, because the
            hash is ``key % capacity``.
            """
            if capacity < 1:
                raise ValueError("capacity must be a positive integer")
            self.capacity = capacity
            self.size = 0
            self.table = np.full(self.capacity, EMPTY) # Initialize with a special value indicating empty

        def __len__(self) -> int:
            """Return the number of keys currently stored."""
            return self.size

        @staticmethod
        def _hash(key: numba.uint64, capacity: numba.uint64) -> numba.uint64:
            """Map ``key`` to its home slot."""
            return key % capacity

        def insert(self, key: numba.uint64) -> bool:
            """Add ``key``; return True if added, False if already present.

            Raises ValueError for the reserved marker values and RuntimeError
            when every slot already holds a key.
            """
            if key == EMPTY or key == DELETED:
                raise ValueError("Key is reserved as a slot marker")
            if self.size >= self.capacity:
                raise RuntimeError("Hash table is full")

            index = self._hash(key, self.capacity)
            free_slot = index  # only meaningful once found_free is True
            found_free = False
            # Search the whole probe chain for the key before writing: taking
            # the first DELETED slot straight away would duplicate a key that
            # lives further along the chain. The probe count is bounded so a
            # table without EMPTY slots cannot loop forever.
            for _ in range(self.capacity):
                slot = self.table[index]
                if slot == key:
                    return False  # Key already exists
                if slot == EMPTY:
                    # End of the chain: prefer an earlier tombstone if any.
                    if not found_free:
                        free_slot = index
                        found_free = True
                    break
                if slot == DELETED and not found_free:
                    free_slot = index
                    found_free = True
                index = (index + numba.uint64(1)) % self.capacity

            if not found_free:
                raise RuntimeError("Hash table is full")

            self.table[free_slot] = key
            self.size += 1
            return True

        def contains(self, key: numba.uint64) -> bool:
            """Return True if ``key`` is currently stored in the set."""
            # Marker values are never stored; without this guard a lookup of
            # DELETED would match any tombstone.
            if key == EMPTY or key == DELETED:
                return False

            index = self._hash(key, self.capacity)
            for _ in range(self.capacity):  # bounded: there may be no EMPTY slot left
                slot = self.table[index]
                if slot == key:
                    return True
                if slot == EMPTY:
                    return False
                index = (index + numba.uint64(1)) % self.capacity
            return False

        def remove(self, key: numba.uint64) -> bool:
            """Remove ``key`` if present; return True if something was removed."""
            # Removing a marker value must not "succeed" on a tombstone and
            # decrement the size counter.
            if key == EMPTY or key == DELETED:
                return False

            index = self._hash(key, self.capacity)
            for _ in range(self.capacity):  # bounded: there may be no EMPTY slot left
                slot = self.table[index]
                if slot == key:
                    self.table[index] = DELETED
                    self.size -= 1
                    return True
                if slot == EMPTY:
                    return False
                index = (index + numba.uint64(1)) % self.capacity
            return False