
Making the Sieve of Eratosthenes more memory efficient in Python?


Sieve of Eratosthenes memory constraint issue

I'm currently trying to implement a version of the Sieve of Eratosthenes for a Kattis problem; however, I am running into memory constraints that my implementation won't pass.

Here is a link to the problem statement. In short, the problem wants me to first output the number of primes less than or equal to n, and then answer a number of queries, each asking whether a number i is prime. There is a memory limit of 50 MB, and only Python's standard library may be used (no NumPy etc.). The memory constraint is where I am stuck.

Here is my code so far:

import sys

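def sieve_of_eratosthenes(xs, n):
    # xs holds the odd numbers 3, 5, 7, ... (xs[k] == 2k + 3); 0 marks a composite.
    count = len(xs) + (1 if n >= 2 else 0)  # all odd candidates, plus the even prime 2
    p = 3 # start at three
    index = 0
    while p*p <= n:  # <=, so squares of primes (9, 25, ...) are crossed off too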
        for i in range(index + p, len(xs), p):
            if xs[i]:
                xs[i] = 0
                count -= 1

        temp_index = index
        for i in range(index + 1, len(xs)):
            if xs[i]:
                p = xs[i]
                temp_index += 1
                break
            temp_index += 1
        index = temp_index

    return count


def isPrime(xs, a):
    if a == 1:
        return False
    if a == 2:
        return True
    if not (a & 1):
        return False
    return bool(xs[(a >> 1) - 1])

def main():
    n, q = map(int, sys.stdin.readline().split(' '))
    odds = [num for num in range(2, n+1) if (num & 1)]
    print(sieve_of_eratosthenes(odds, n))

    for _ in range(q):
        query = int(input())
        if isPrime(odds, query):
            print('1')
        else:
            print('0')


if __name__ == "__main__":
    main()

I've made some improvements so far, such as keeping a list of only the odd numbers, which halves the memory usage. I am also certain that the code computes the primes correctly (it does not give wrong answers). My question now is: how can I make my code even more memory efficient? Should I use some other data structure? Replace my list of integers with booleans? A bitarray?

Any advice is much appreciated!

EDIT

After some tweaking of the Python code, I hit a wall: my implementation of a segmented sieve would not pass the memory requirements.

Instead, I chose to implement the solution in Java, which took very little effort. Here is the code:

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.BitSet;

// Class name is illustrative; use whatever name the submission file requires.
public class PrimeSieve {

  // Bit (k - 3) / 2 is set when the odd number k >= 3 is composite.
  private BitSet sieve;

  public int sieveOfEratosthenes(int n){
    sieve = new BitSet((n+1) / 2);
    // The prime 2 plus every odd number 3..n; composites are subtracted below.
    int count = n < 2 ? 0 : (n + 1) / 2;

    for (int i=3; i*i <= n; i += 2){
      if (isComposite(i)) {
        continue;
      }

      // Start at i*i and step by 2*i, skipping all even multiples
      for (int c = i * i; c <= n; c += 2 * i){
        if(!isComposite(c)){
          setComposite(c);
          count--;
        }
      }
    }

    return count;
  }

  public boolean isComposite(int k) {
    return sieve.get((k - 3) / 2); // Since we don't keep track of even numbers
  }

  public void setComposite(int k) {
    sieve.set((k - 3) / 2); // Since we don't keep track of even numbers
  }

  public boolean isPrime(int a) {
    if (a < 3)
      return a > 1; // 2 is prime; 0 and 1 are not

    if ((a & 1) == 1)
      return !isComposite(a);
    else
      return false; // even numbers greater than 2
  }

  public void run() throws Exception{
    BufferedReader scan = new BufferedReader(new InputStreamReader(System.in));
    String[] line = scan.readLine().split(" ");

    int n = Integer.parseInt(line[0]); int q = Integer.parseInt(line[1]);
    System.out.println(sieveOfEratosthenes(n));

    for (int i=0; i < q; i++){
      line = scan.readLine().split(" ");
      System.out.println( isPrime(Integer.parseInt(line[0])) ? '1' : '0');
    }
  }

  public static void main(String[] args) throws Exception {
    new PrimeSieve().run();
  }
}

I have not personally found a way to implement this BitSet solution in Python (using only the standard library).

If anyone stumbles across a neat implementation of the problem in Python, using a segmented sieve, a bit array or something else, I would be interested to see it.
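One standard-library-only option is a segmented sieve that answers the queries offline: read all queries first, sort them, and answer each one while the segment containing it is in memory. Only the base primes up to sqrt(n), one segment and the queries are held at any time. This is a sketch, assuming the input format used by the code above (a line "n q", then q integers) and that every query lies in 0..n; `simple_primes`, `segmented_sieve`, `main` and the default `segment_size` of 65536 are illustrative choices, and the sketch has not been timed against the Kattis limits.

```python
import math
import sys


def simple_primes(limit):
    """Return all primes <= limit with a plain bytearray sieve (used for base primes)."""
    if limit < 2:
        return []
    flags = bytearray([1]) * (limit + 1)
    flags[0] = flags[1] = 0
    for p in range(2, math.isqrt(limit) + 1):
        if flags[p]:
            # bytes(k) is k zero bytes, so no large temporary list is built.
            flags[p * p :: p] = bytes(len(range(p * p, limit + 1, p)))
    return [p for p in range(2, limit + 1) if flags[p]]


def segmented_sieve(n, queries, segment_size=1 << 16):
    """Count primes <= n and answer 'is queries[i] prime?' one segment at a time.

    Assumes 0 <= queries[i] <= n. Returns (prime_count, answers), where
    answers[i] is 1 if queries[i] is prime and 0 otherwise.
    """
    base = simple_primes(math.isqrt(n))
    order = sorted(range(len(queries)), key=queries.__getitem__)  # query indices by value
    answers = [0] * len(queries)
    count = 0
    qi = 0
    for low in range(2, n + 1, segment_size):
        high = min(low + segment_size, n + 1)       # segment covers [low, high)
        seg = bytearray([1]) * (high - low)         # 1 = not yet crossed off
        for p in base:
            if p * p >= high:
                break                               # no multiples of p left to mark here
            start = max(p * p, (low + p - 1) // p * p)  # first multiple >= low, never p itself
            seg[start - low :: p] = bytes(len(range(start, high, p)))
        count += sum(seg)
        # Answer every query whose value lies below the end of this segment.
        while qi < len(order) and queries[order[qi]] < high:
            value = queries[order[qi]]
            if value >= low:                        # 0 and 1 keep their answer of 0
                answers[order[qi]] = seg[value - low]
            qi += 1
    return count, answers


def main():
    # Call main() at module level when submitting.
    data = sys.stdin.buffer.read().split()
    n, q = int(data[0]), int(data[1])
    queries = [int(x) for x in data[2 : 2 + q]]
    count, answers = segmented_sieve(n, queries)
    sys.stdout.write("\n".join([str(count)] + [str(a) for a in answers]) + "\n")
```

For example, `segmented_sieve(10, [1, 2, 9])` returns `(4, [0, 1, 0])`. The trade-off is that all queries must be stored before any answer is printed, which is cheap compared to a 10^8-entry array.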


Solution

  • This is a very challenging problem indeed. With a maximum possible N of 10^8, using one byte per value results in about 100 MB of data, even assuming no overhead whatsoever. Even halving the data by storing only odd numbers puts you very close to 50 MB once overhead is considered.
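    The arithmetic above is easy to reproduce for the layouts discussed in this answer. This small calculation counts only the flag storage for N = 10^8 and ignores interpreter overhead; the 6k ± 1 row uses N // 3 as an approximation of how many such numbers there are.

```python
N = 10**8  # largest n allowed by the problem, per the answer above


def mib(num_bytes):
    """Convert a byte count to mebibytes (2**20 bytes)."""
    return num_bytes / 2**20


# Bytes needed for the primality flags alone, ignoring interpreter overhead.
estimates = {
    "one byte per number": N + 1,
    "one byte per odd number": (N + 1) // 2,
    "one byte per 6k +/- 1 number": N // 3,  # approximately N/3 numbers
    "one bit per odd number": ((N + 1) // 2 + 7) // 8,
}

for label, size in estimates.items():
    print(f"{label:30s} {size:>12,d} bytes  ~{mib(size):5.1f} MiB")
```

    The odd-only layout alone is 5 × 10^7 bytes, right at the limit before the Python interpreter's own memory is counted.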

    This means the solution will have to make use of one or more of a few strategies:

    1. Using a more efficient data type for our array of primality flags. Python lists maintain an array of pointers to the list items (8 bytes each on a 64-bit Python), and each int stored that way is a separate object of roughly 28 bytes on top of that. We effectively need raw binary storage, which in the standard library means bytearray (array.array('B') is a similar option).
    2. Using only one bit per value in the sieve instead of an entire byte (a boolean technically needs only one bit, but is typically stored in at least a full byte).
    3. Skipping even numbers, and possibly also multiples of 3, 5, 7, etc. (wheel factorization).
    4. Using a segmented sieve.
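    To see the gap between points 1 and 2 in practice, this measurement sketch uses sys.getsizeof on a smaller n = 10**6 (chosen so it runs quickly; the sizes scale roughly linearly with n). sys.getsizeof on a list reports only its pointer array, so the int objects are added separately; because CPython caches small ints, that total is a slight overestimate.

```python
import sys

n = 10**6  # small n so the measurement is quick
odds_list = list(range(3, n + 1, 2))      # the question's approach: one int object per value
flags_list = [True] * ((n + 1) // 2)      # list of bools: still one pointer per entry
flags_bytes = bytearray((n + 1) // 2)     # one byte per odd number

list_ptrs = sys.getsizeof(odds_list)                     # pointer array only
int_objs = sum(sys.getsizeof(x) for x in odds_list)      # the int objects themselves
bool_ptrs = sys.getsizeof(flags_list)                    # True/False are shared singletons
raw = sys.getsizeof(flags_bytes)                         # raw bytes plus a small header

print(f"list of ints : {list_ptrs + int_objs:>12,d} bytes (pointers + int objects)")
print(f"list of bools: {bool_ptrs:>12,d} bytes (pointers only)")
print(f"bytearray    : {raw:>12,d} bytes")
```

    The bytearray is the only one of the three whose size is close to one byte per flag.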

    I initially tried to solve the problem by storing only 1 bit per value in the sieve, and while the memory usage was indeed within the requirements, Python's slow bit manipulation made the execution time far too long. It was also rather difficult to get the complex indexing right so that the correct bits were counted reliably.
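    A one-bit-per-value layout that stores only odd numbers can be sketched as follows: bit k stands for 2k + 1, so its flag lives in byte k >> 3 at bit position k & 7. This is a sketch of the general technique, not the code described above; `bit_sieve`, `count_primes` and `is_prime` are illustrative names. The array takes only about n/16 bytes, but the per-bit Python loop is exactly the part that becomes slow for n near 10^8.

```python
def bit_sieve(n):
    """Sieve with one bit per odd number; bit k represents 2k + 1 (1 = composite)."""
    size = (n + 1) // 2                  # odd numbers 1, 3, ..., <= n
    bits = bytearray((size + 7) // 8)    # 8 flags per byte
    if size:
        bits[0] |= 1                     # 1 is not prime
    i = 3
    while i * i <= n:
        k = i >> 1
        if not bits[k >> 3] & (1 << (k & 7)):
            # Odd multiples i*i, i*i + 2i, ... are i apart in index space.
            for j in range((i * i) >> 1, size, i):
                bits[j >> 3] |= 1 << (j & 7)
        i += 2
    return bits, size


def count_primes(bits, size, n):
    """Unmarked odd numbers, plus the even prime 2 when n >= 2."""
    composite = sum(bin(b).count("1") for b in bits)  # padding bits are never set
    return size - composite + (1 if n >= 2 else 0)


def is_prime(bits, m):
    """Look up m (0 <= m <= n) in a table built by bit_sieve(n)."""
    if m == 2:
        return True
    if m < 2 or m % 2 == 0:
        return False
    k = m >> 1
    return not bits[k >> 3] & (1 << (k & 7))
```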

    I then implemented an odd-numbers-only solution using a bytearray, and while it was quite a bit faster, memory was still an issue.

    Bytearray odd numbers implementation:

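    class Sieve:
        def __init__(self, n):
            # Index k stands for the odd number 2k + 1; even numbers are not stored.
            size = (n + 1) // 2
            self.not_prime = bytearray(size)
            if size:
                self.not_prime[0] = 1  # the number 1 is not prime
            i = 3
            while i * i <= n:
                if not self.not_prime[i // 2]:
                    # Odd multiples i*i, i*i + 2i, ... are i apart in index space.
                    # A bytes object avoids building a huge temporary list of ints.
                    start = i * i // 2
                    self.not_prime[start::i] = b'\x01' * len(range(start, size, i))
                i += 2
            # Unmarked odd numbers, plus the even prime 2.
            self.n_prime = size - sum(self.not_prime) + (1 if n >= 2 else 0)

        def is_prime(self, n):
            if n == 2:
                return 1
            if n < 2 or n % 2 == 0:
                return 0
            return int(not self.not_prime[n // 2])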
    
    
    def main():
        n, q = map(int, input().split())
        s = Sieve(n)
        print(s.n_prime)
        for _ in range(q):
            i = int(input())
            print(s.is_prime(i))
    
    if __name__ == "__main__":
        main()
    

    Reducing memory further from this point should, in principle, make it work.

    EDIT: Also removing multiples of 2 and 3 did not seem to reduce memory enough, even though guppy.hpy().heap() suggested my usage was in fact a bit under 50 MB. 🤷‍♂️
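    For reference, a wheel that removes multiples of 2 and 3 can be sketched like this: only numbers of the form 6k ± 1 are stored, and such a number m sits at index m // 3 (1 → 0, 5 → 1, 7 → 2, 11 → 3, ...). Each prime p crosses off two progressions, starting at p*p and at p*q2 (q2 being the next cofactor after p that is coprime to 6), both with an index step of 2p. This is an illustrative reconstruction with hypothetical names (`wheel6_sieve`, `count_primes_wheel6`, `is_prime_wheel6`), not the code the EDIT refers to; it still needs about n/3 bytes plus temporary bytes objects while marking, so it is not claimed to fit the limit.

```python
def wheel6_sieve(n):
    """Composite flags for numbers coprime to 6 only: index i <-> 3*i + 1 + (i & 1)."""
    r = n % 6
    size = n // 6 * 2 + (r >= 1) + (r >= 5)    # count of 1, 5, 7, 11, ... <= n
    flags = bytearray(size)                     # 1 = composite (or the number 1)
    if size:
        flags[0] = 1                            # index 0 is the number 1
    i = 1
    while True:
        p = 3 * i + 1 + (i & 1)                 # 5, 7, 11, 13, ...
        if p * p > n:
            break
        if not flags[i]:
            step = 2 * p                        # value step 6p == index step 2p
            q2 = p + (4 if p % 6 == 1 else 2)   # next cofactor coprime to 6 after p
            for start in ((p * p) // 3, (p * q2) // 3):
                flags[start::step] = b"\x01" * len(range(start, size, step))
        i += 1
    return flags


def count_primes_wheel6(n, flags):
    """Unmarked 6k +/- 1 numbers, plus the primes 2 and 3 themselves."""
    return len(flags) - sum(flags) + (n >= 2) + (n >= 3)


def is_prime_wheel6(flags, m):
    """Look up m (0 <= m <= n) in a table built by wheel6_sieve(n)."""
    if m in (2, 3):
        return True
    if m < 2 or m % 2 == 0 or m % 3 == 0:
        return False
    return not flags[m // 3]
```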