I'm currently trying to implement a version of the sieve of Eratosthenes for a Kattis problem; however, I am running into memory constraints that my implementation won't pass.
Here is a link to the problem statement. In short, the problem wants me to first return the number of primes less than or equal to n and then answer a certain number of queries asking whether a number i is prime. There is a 50 MB memory limit, and only the Python standard library may be used (no numpy etc.). The memory constraint is where I am stuck.
Here is my code so far:
import sys


def sieve_of_eratosthenes(xs, n):
    # xs holds the odd numbers 3, 5, 7, ...; crossed-out entries become 0.
    count = len(xs) + 1  # every odd candidate, plus the prime 2
    p = 3  # start at three
    index = 0
    while p*p <= n:  # <= so that n == p*p is crossed out as well
        # Entries p positions apart in xs are consecutive odd multiples of p.
        for i in range(index + p, len(xs), p):
            if xs[i]:
                xs[i] = 0
                count -= 1

        # Advance to the next entry that has not been crossed out.
        temp_index = index
        for i in range(index + 1, len(xs)):
            if xs[i]:
                p = xs[i]
                temp_index += 1
                break
            temp_index += 1
        index = temp_index

    return count


def isPrime(xs, a):
    if a == 1:
        return False
    if a == 2:
        return True
    if not (a & 1):
        return False
    # The odd number a is stored at index (a >> 1) - 1.
    return bool(xs[(a >> 1) - 1])


def main():
    n, q = map(int, sys.stdin.readline().split(' '))
    odds = [num for num in range(2, n+1) if (num & 1)]
    print(sieve_of_eratosthenes(odds, n))

    for _ in range(q):
        query = int(input())
        if isPrime(odds, query):
            print('1')
        else:
            print('0')


if __name__ == "__main__":
    main()
I've made some improvements so far, such as only keeping a list of the odd numbers, which halves the memory usage. I am also certain that the code works as intended when calculating the primes (I am not getting the wrong answer). My question now is: how can I make my code even more memory efficient? Should I use some other data structure? Replace my list of integers with booleans? A bitarray?
Any advice is much appreciated!
After some tweaking of the Python code, I hit a wall: my implementation of a segmented sieve would not pass the memory requirements.
Instead, I chose to implement the solution in Java, which took very little effort. Here is the code:
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.BitSet;

// Wrapper class, field and main() added so the snippet compiles; the class name is arbitrary.
public class Main {
    private BitSet sieve; // bit (k - 3) / 2 is set when the odd number k is composite

    public int sieveOfEratosthenes(int n){
        sieve = new BitSet((n+1) / 2);
        // 2 plus every odd number in 3..n starts out as a prime candidate
        int count = n < 2 ? 0 : (n + 1) / 2;

        for (int i=3; i*i <= n; i += 2){
            if (isComposite(i)) {
                continue;
            }

            // Step by 2 * i, skipping all even multiples
            for (int c = i * i; c <= n; c += 2 * i){
                if(!isComposite(c)){
                    setComposite(c);
                    count--;
                }
            }
        }

        return count;
    }

    public boolean isComposite(int k) {
        return sieve.get((k - 3) / 2); // Since we don't keep track of even numbers
    }

    public void setComposite(int k) {
        sieve.set((k - 3) / 2); // Since we don't keep track of even numbers
    }

    public boolean isPrime(int a) {
        if (a < 3)
            return a > 1; // covers 0, 1 and 2

        if ((a & 1) == 1)
            return !isComposite(a);
        else
            return false;
    }

    public void run() throws Exception{
        BufferedReader scan = new BufferedReader(new InputStreamReader(System.in));
        String[] line = scan.readLine().split(" ");

        int n = Integer.parseInt(line[0]); int q = Integer.parseInt(line[1]);
        System.out.println(sieveOfEratosthenes(n));

        for (int i=0; i < q; i++){
            line = scan.readLine().split(" ");
            System.out.println( isPrime(Integer.parseInt(line[0])) ? '1' : '0');
        }
    }

    public static void main(String[] args) throws Exception {
        new Main().run();
    }
}
I have personally not found a way to implement this BitSet solution in Python (using only the standard library).
If anyone stumbles across a neat implementation of the problem in Python, using a segmented sieve, a bitarray or something else, I would be interested to see the solution.
This is a very challenging problem indeed. With a maximum possible N of 10^8, using one byte per value results in almost 100 MB of data, assuming no overhead whatsoever. Even halving the data by only storing odd numbers will put you very close to 50 MB once overhead is considered.
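To make the arithmetic above concrete, the rough sizes can be computed directly. This is only a back-of-the-envelope sketch: it takes N = 10^8 as stated, reports MiB (how the judge defines "MB" is an assumption), and the figure of 8 bytes per list slot assumes 64-bit CPython and ignores the integer objects themselves.

```python
N = 10**8              # maximum n, as stated above
MIB = 1024 * 1024


def mib(num_bytes):
    """Convert a byte count to MiB."""
    return num_bytes / MIB


odd_values = (N + 1) // 2                      # 1, 3, 5, ..., up to N

byte_per_value = mib(N + 1)                    # one byte for each of 0..N
byte_per_odd = mib(odd_values)                 # one byte per odd number
bit_per_odd = mib((odd_values + 7) // 8)       # one bit per odd number
list_slots_per_odd = mib(8 * odd_values)       # list pointer slots only (64-bit CPython assumed)

for label, size in [
    ("one byte per value", byte_per_value),
    ("one byte per odd value", byte_per_odd),
    ("one bit per odd value", bit_per_odd),
    ("list slots per odd value", list_slots_per_odd),
]:
    print(f"{label:<26}{size:8.1f} MiB")
```

The last row is why a plain list of integers, as in the question, is far outside the limit before any sieving starts.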
This means the solution will have to make use of one or more of a few strategies: a more compact container for the flags (which in standard Python essentially means bytearray), storing only one bit per value, or leaving out even numbers (and possibly multiples of 3) altogether.
I initially tried to solve the problem by storing only 1 bit per value in the sieve, and while the memory usage was indeed within the requirements, Python's slow bit manipulation made the execution time far too long. It was also rather difficult to figure out the complex indexing needed to make sure the correct bits were being counted reliably.
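For illustration, the one-bit-per-value indexing could look roughly like the following. This is a minimal sketch, not the code that was actually submitted: the names `make_bit_sieve` and `is_prime` are hypothetical, and the layout (bit `(k - 3) // 2` for the odd number `k`) is borrowed from the Java answer above. The per-bit Python loop is exactly the part that is slow.

```python
def make_bit_sieve(n):
    """Return (bits, count); bit (k - 3) // 2 is set when odd k >= 3 is composite."""
    size = max((n - 1) // 2, 0)            # how many odd numbers lie in 3..n
    bits = bytearray((size + 7) // 8)      # 8 flags per byte, all initially "prime"
    count = size + 1 if n >= 2 else 0      # odd candidates plus the prime 2

    i = 3
    while i * i <= n:
        idx = (i - 3) // 2
        if not (bits[idx >> 3] >> (idx & 7)) & 1:      # i is still marked prime
            for c in range(i * i, n + 1, 2 * i):       # odd multiples of i only
                j = (c - 3) // 2
                mask = 1 << (j & 7)
                if not bits[j >> 3] & mask:
                    bits[j >> 3] |= mask
                    count -= 1
        i += 2
    return bits, count


def is_prime(bits, a):
    """Query the bit sieve; a must not exceed the n the sieve was built for."""
    if a < 3:
        return a == 2
    if a & 1 == 0:
        return False
    j = (a - 3) // 2
    return not (bits[j >> 3] >> (j & 7)) & 1
```

Byte `j >> 3` holds the flag and `j & 7` selects the bit inside it; for n = 10^8 the array is about 6 MB, so memory is no longer the problem, but the inner loop runs once per crossed-out multiple in pure Python.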
I then implemented the odd-numbers-only solution using a bytearray, and while it was quite a bit faster, the memory was still an issue.
Bytearray implementation (the version shown here keeps a flag for every number from 0 to n):
class Sieve:
    def __init__(self, n):
        # One byte per number in 0..n; 1 means "not prime".
        self.not_prime = bytearray(n+1)
        self.not_prime[0] = self.not_prime[1] = 1
        for i in range(2, int(n**.5)+1):
            if self.not_prime[i] == 0:
                # Assign a bytes object rather than a list of ints: a list
                # costs several bytes per element while it exists.
                self.not_prime[i*i::i] = b'\x01' * len(range(i*i, n+1, i))
        self.n_prime = n + 1 - sum(self.not_prime)

    def is_prime(self, n):
        return int(not self.not_prime[n])


def main():
    n, q = map(int, input().split())
    s = Sieve(n)
    print(s.n_prime)
    for _ in range(q):
        i = int(input())
        print(s.is_prime(i))


if __name__ == "__main__":
    main()
Further reduction in memory from this should* make it work.
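One such reduction is to keep flags for the odd numbers only, which halves the array. The following is a hedged sketch under that assumption; `odd_sieve` and `is_prime` are illustrative names, and whether this fits the judge's limit is not verified here.

```python
def odd_sieve(n):
    """Return (flags, count); flags[k] == 1 means the odd number 2*k + 1 is prime."""
    if n < 2:
        return bytearray(), 0
    size = (n + 1) // 2                    # odd numbers 1, 3, 5, ..., up to n
    flags = bytearray(b"\x01") * size
    flags[0] = 0                           # 1 is not prime

    i = 3
    while i * i <= n:
        if flags[i >> 1]:
            start = (i * i) >> 1           # index of i*i
            # Consecutive odd multiples of i differ by 2*i, i.e. by i in index space.
            flags[start::i] = bytes(len(range(start, size, i)))
        i += 2
    return flags, 1 + flags.count(1)       # +1 for the prime 2


def is_prime(flags, a):
    """Query the odd-only sieve; a must not exceed the n it was built for."""
    if a == 2:
        return True
    if a < 2 or a % 2 == 0:
        return False
    return bool(flags[a >> 1])
```

The slice assignment keeps the crossing-out inside C code, so it stays fast; the temporary `bytes` objects are at most about a sixth of the array in size (the first one, for i = 3).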
EDIT:
Also removing multiples of 2 and 3 did not seem to be enough of a memory reduction, even though guppy.hpy().heap() seemed to suggest my usage was in fact a bit under 50 MB. 🤷‍♂️
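For reference, leaving out the multiples of 2 and 3 can be sketched as follows. This is not the code referred to in the edit above, only one possible layout with hypothetical names: every number m coprime to 6 gets the index `m // 3`, so the array needs roughly n/3 bytes. As noted, an approach along these lines was still reported as exceeding the limit on the judge.

```python
def wheel_sieve(n):
    """Return (flags, count); flags[m // 3] == 1 means m (coprime to 6) is prime."""
    if n < 5:
        return bytearray(), sum(1 for p in (2, 3) if p <= n)

    def value(i):
        return 3 * i + 1 + (i & 1)         # index -> 1, 5, 7, 11, 13, ...

    size = n // 3 + 1
    if value(size - 1) > n:                # the last slot may overshoot n
        size -= 1
    flags = bytearray(b"\x01") * size
    flags[0] = 0                           # 1 is not prime

    for i in range(1, size):
        p = value(i)
        if p * p > n:
            break
        if flags[i]:
            for r in (1, 5):
                m = p + (r - p) % 6        # smallest m >= p with m % 6 == r
                start = (p * m) // 3       # index of p*m
                # p*(m + 6) lies exactly 2*p slots further on.
                flags[start::2 * p] = bytes(len(range(start, size, 2 * p)))
    return flags, 2 + flags.count(1)       # +2 for the primes 2 and 3


def is_prime(flags, a):
    """Query the wheel sieve; a must not exceed the n it was built for."""
    if a in (2, 3):
        return True
    if a < 2 or a % 2 == 0 or a % 3 == 0:
        return False
    return bool(flags[a // 3])
```

For n = 10^8 the flag array is about 32 MiB; the interpreter itself and the temporary objects created by the slice assignments come on top of that, which may be part of why the measured heap and the judge's verdict disagree.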