python, memory, data-structures, space-complexity, python-internals

How can I store ids in Python without paying the 28-byte-per-int price?


My Python code stores millions of ids in various data structures in order to implement a classic algorithm. The run time is good, but the memory usage is awful.

These ids are ints. I assume that since Python ints start at 28 bytes and grow from there, they carry a huge price. Since they're just opaque ids, not actual mathematical objects, I could get by with just 4 bytes each.
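
For example, sys.getsizeof shows the per-int price (numbers are for 64-bit CPython):

    import sys

    print(sys.getsizeof(1))      # 28 -- even tiny ints carry a full object header
    print(sys.getsizeof(2**30))  # 32 -- and the size grows with magnitude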

Is there a way to store ids in Python that won't use the full 28 bytes each? Note that I need to use them as both keys and values in dicts.

Note: The common solution of using something like NumPy won't work here, because my data isn't a contiguous array. It's keys and values of a dict, dicts of dicts, etc.

I'm also amenable to other Python interpreters that are less memory-hungry for ints.


Solution

  • Your use case stores IDs as both keys and values of a dict. But since keys and values of a dict have to be Python objects, each one carries a full object header, and the dict stores a pointer to each of them on top of that.

    To actually store keys and values at 4 bytes each, you would have to implement a custom hash table that allocates an array.array of 32-bit unsigned integers for keys and another for values. Since real IDs are typically never 0 or 2**32-1, you can use those two values as sentinels for an empty slot and a deleted slot, respectively.

    Below is a sample implementation with linear probing:

    from array import array
    
    class HashTable:
        # Sentinel slot values; assumes real IDs are never 0 or 2**32 - 1.
        EMPTY = 0
        DELETED = (1 << 32) - 1
    
        def __init__(self, source=None, size=8, load_factor_threshold=0.75):
            self._size = size
            self._load_factor_threshold = load_factor_threshold
            self._count = 0
            # Flat C arrays of unsigned ints: no per-item object header or pointer.
            self._keys = array('L', [self.EMPTY]) * size
            self._values = array('L', [self.EMPTY]) * size
            if source is not None:
                self.update(source)
    
        def _probe(self, key):
            # Linear probing: yield up to _size slots, starting at the key's
            # hash index and wrapping around the end of the table.
            index = hash(key) % self._size
            for _ in range(self._size):
                yield index, self._keys[index], self._values[index]
                index = (index + 1) % self._size
    
        def __setitem__(self, key, value):
            # Grow once the load factor threshold is reached: rehash all
            # live items into a table of twice the size.
            while self._count >= self._load_factor_threshold * self._size:
                new = HashTable(self, self._size * 2, self._load_factor_threshold)
                self._size = new._size
                self._keys = new._keys
                self._values = new._values
            first_deleted = None
            for index, probed_key, probed_value in self._probe(key):
                if probed_value == self.DELETED:
                    # Remember the first tombstone so its slot can be reused,
                    # but keep probing in case the key exists further along.
                    if first_deleted is None:
                        first_deleted = index
                    continue
                if probed_value == self.EMPTY:
                    # Key is absent; insert it, preferring a tombstone slot.
                    if first_deleted is not None:
                        index = first_deleted
                    self._keys[index] = key
                    self._values[index] = value
                    self._count += 1
                    return
                if probed_key == key:
                    self._values[index] = value
                    return
            if first_deleted is not None:
                # Every probed slot was live or a tombstone; reuse the first tombstone.
                self._keys[first_deleted] = key
                self._values[first_deleted] = value
                self._count += 1
    
        def __getitem__(self, key):
            for _, probed_key, value in self._probe(key):
                if value == self.EMPTY:
                    # An empty slot ends the probe chain: the key is absent.
                    break
                if value == self.DELETED:
                    # Tombstones don't end the chain; keep probing.
                    continue
                if probed_key == key:
                    return value
            raise KeyError(key)
    
        def __delitem__(self, key):
            for index, probed_key, value in self._probe(key):
                if value == self.EMPTY:
                    raise KeyError(key)
                if value == self.DELETED:
                    continue
                if probed_key == key:
                    # Leave a tombstone instead of emptying the slot, so probe
                    # chains that pass through it stay intact.
                    self._values[index] = self.DELETED
                    self._count -= 1
                    return
            raise KeyError(key)
    
        def items(self):
            # Yield only live entries, skipping empty and tombstoned slots.
            for key, value in zip(self._keys, self._values):
                if value not in (self.EMPTY, self.DELETED):
                    yield key, value
    
        def keys(self):
            for key, _ in self.items():
                yield key
    
        def values(self):
            for _, value in self.items():
                yield value
    
        def __iter__(self):
            yield from self.keys()
    
        def __len__(self):
            return self._count
    
        def __eq__(self, other):
            return set(self.items()) == set(other.items())
    
        def __contains__(self, key):
            try:
                self[key]
            except KeyError:
                return False
            return True
    
        def get(self, key, default=None):
            try:
                return self[key]
            except KeyError:
                return default
    
        def __repr__(self):
            return repr(dict(self.items()))
    
        def __str__(self):
            return repr(self)
    
        def copy(self):
            return HashTable(self, self._size, self._load_factor_threshold)
    
        def update(self, other):
            for key, value in other.items():
                self[key] = value
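
    As a quick sanity check, the table behaves like a dict for the operations it supports (the IDs below are arbitrary example values):

    h = HashTable()
    h[12345] = 67890
    assert h[12345] == 67890
    assert 12345 in h and len(h) == 1
    del h[12345]
    assert 12345 not in h and h.get(12345, -1) == -1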
    

    Measuring with pympler.asizeof, which recursively computes the memory footprint of an object, shows a saving of roughly 90% (the table doubles up to 2**21 slots, and two 4-byte arrays of that size come to about 16.8 MB):

    from pympler.asizeof import asizeof
    
    d = dict(zip(range(1, 1500001), range(1, 1500001)))  # avoid 0, the EMPTY sentinel
    h = HashTable(d)
    print(asizeof(d)) # 179877936
    print(asizeof(h)) # 16777920
    

    Note that on some platforms (e.g. 64-bit Linux and macOS) the type code 'L' for array.array has an item size of 8 bytes instead of 4, in which case you should use the type code 'I' instead.
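
    For instance, a small sketch of picking the type code portably at runtime (TYPECODE is a name introduced here for illustration):

    from array import array

    # 'I' is 4 bytes on common platforms; 'L' may be 4 (Windows) or 8 (64-bit Unix).
    TYPECODE = 'L' if array('L').itemsize == 4 else 'I'
    assert array(TYPECODE).itemsize == 4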