
Preventing overflow of large integers in (GPU) optimized methods such as gmpy2 and numba


I am trying to check whether a large integer is a perfect square using gmpy2 inside a JIT-decorated (optimized) Numba function. The example is for illustrative purposes only (from a theoretical point of view, such equations or elliptic curves can be treated differently and better). My code seems to overflow, since it reports solutions that aren't actually solutions:

import numpy as np
from numba import jit
import gmpy2
from gmpy2 import mpz, xmpz

import time
import sys

@jit('void(uint64)')
def findIntegerSolutionsGmpy2(limit: np.uint64):
    for x in np.arange(0, limit+1, dtype=np.uint64):
        y = mpz(x**6-4*x**2+4)
        if gmpy2.is_square(y):
            print([x,gmpy2.sqrt(y),y])

def main() -> int:
    limit = 100000000
    start = time.time()
    findIntegerSolutionsGmpy2(limit)
    end = time.time()
    print("Time elapsed: {0}".format(end - start))
    return 0

if __name__ == '__main__':
    sys.exit(main())

With limit = 1000000000, the routine finishes in approx. 4 seconds. The limit I pass to the decorated function will not exceed the range of a 64-bit unsigned integer (which does not seem to be the issue here).

I read that big integers do not work in combination with numba's JIT optimization (see for example here).

My Question: Is there any possibility to use large integers in (GPU) optimized code?


Solution

  • The real reason for the wrong results is simple: you forgot to convert x to mpz, so the expression x ** 6 - 4 * x ** 2 + 4 is evaluated in np.uint64 arithmetic (because x is an np.uint64) and silently wraps around modulo 2**64 once the value no longer fits. The fix is trivial, just add x = mpz(x):

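    @jit('void(uint64)', forceobj = True)
    def findIntegerSolutionsGmpy2(limit: np.uint64):
        for x in np.arange(0, limit+1, dtype=np.uint64):
            x = mpz(x)  # convert first, so all following arithmetic is arbitrary-precision
            y = x**6 - 4*x**2 + 4
            if gmpy2.is_square(y):
                # isqrt returns the exact integer root (gmpy2.sqrt would return an mpfr float)
                print([x, gmpy2.isqrt(y), y])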
    

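    Also, you may notice that I added forceobj = True: it tells Numba to compile the function in object mode from the start (gmpy2's mpz objects cannot be used in nopython mode), which suppresses Numba's compilation warnings at startup.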

    After this fix, everything works and no spurious solutions are reported.
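    To see where the wrong results come from, here is a small sketch using plain NumPy arrays (no Numba or gmpy2; the helper names are made up for illustration). np.uint64 arithmetic in NumPy wraps modulo 2**64 just like the uint64 arithmetic in the original loop, while Python ints, like mpz, never overflow. The sketch finds the first x for which the two disagree.

```python
import numpy as np

def poly_uint64(x):
    # x is a NumPy uint64 array: every operation silently wraps modulo 2**64
    return x**6 - np.uint64(4) * x**2 + np.uint64(4)

def poly_exact(x):
    # x is a Python int (arbitrary precision, like gmpy2.mpz): no overflow
    return x**6 - 4 * x**2 + 4

def first_wrong_x(stop):
    xs = np.arange(0, stop, dtype=np.uint64)
    for x, wrapped in zip(xs.tolist(), poly_uint64(xs).tolist()):
        exact = poly_exact(x)
        # the uint64 result is always the exact value reduced modulo 2**64
        assert wrapped == exact % 2**64
        if wrapped != exact:
            return x
    return None

print(first_wrong_x(2000))  # 1626
```

    Since 1625**6 < 2**64 < 1626**6, every x from 1626 upward is evaluated incorrectly when x stays a uint64.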

    If your task is only to check whether the expression yields an exact square, I also designed and implemented another solution for you; the code is below.

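    It works as follows. If a number is a perfect square, then it is also a square modulo any number N (taking the modulus is the x % N operation). The converse does not hold, so this test can only rule candidates out; whatever passes still needs an exact check.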
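    A tiny pure-Python illustration of this necessary-but-not-sufficient property, using N = 16 only because its set of square residues is short:

```python
N = 16
# all values a perfect square can take modulo N
residues = {k * k % N for k in range(N)}
print(sorted(residues))  # [0, 1, 4, 9]

# every perfect square passes the test...
assert all(k * k % N in residues for k in range(10_000))
# ...but passing proves nothing: 17 % 16 == 1 is a residue, yet 17 is not a square
print(17 % N in residues)  # True
```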

    We can take any number, for example a product of some primes, K = 2 * 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19. Then we can build a simple filter: compute all squares modulo K, mark these residues in a lookup vector (a uint8 array of length K in the code), and keep only those numbers whose value modulo K is marked in this vector.
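    The table construction can be sketched with NumPy as follows. It is a minimal version that mirrors how filts is built in the code below; the names build_square_filter and may_be_square are hypothetical, and K = 60060 is a smaller stand-in for the K above to keep the example light. By the Chinese remainder theorem, the number of squares modulo K is the product of the counts for its prime-power factors (2 for 4, and (p + 1) / 2 for an odd prime p), i.e. 2 * 2 * 3 * 4 * 6 * 7 = 2016 here.

```python
import numpy as np

def build_square_filter(K):
    # filt[r] == 1 exactly when r is a square modulo K (0 included)
    a = np.arange(K, dtype=np.uint64)
    filt = np.zeros(K, dtype=np.uint8)
    filt[a * a % np.uint64(K)] = 1
    return filt

def may_be_square(y, K, filt):
    # necessary condition only: True still needs an exact check such as gmpy2.is_square
    return bool(filt[y % K])

K = 4 * 3 * 5 * 7 * 11 * 13  # 60060
filt = build_square_filter(K)
print(int(filt.sum()), round(float(filt.sum()) / K, 4))  # 2016 0.0336
```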

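    The filter for K (the product of primes mentioned above) marks only about 1% of residues as squares (filter 0 ratio 0.0094 in the output). We can also add a second stage, applying the same kind of filter with other primes, e.g. K2 = 23 * 29 * 31 * 37 * 41, whose table marks about 3.7% of residues. In total, roughly 1% * 3% = 0.03% of the initial candidates would remain if the values were uniformly distributed; the values of y are not, so the output shows somewhat higher pass rates (about 2.1% and 3.9%, leaving about 0.08% of the candidates).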
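    Why the fractions multiply: when the moduli are coprime, the Chinese remainder theorem makes the squares modulo K1 * K2 correspond exactly to pairs (square modulo K1, square modulo K2), so the two stages together act like a single filter modulo K1 * K2. A quick check with small moduli chosen just for this sketch:

```python
from math import gcd

def count_squares(n):
    # number of distinct values k*k % n, i.e. squares modulo n
    return len({k * k % n for k in range(n)})

K1, K2 = 4 * 3 * 5, 7 * 11  # 60 and 77, small coprime stand-ins for K and K2
assert gcd(K1, K2) == 1
print(count_squares(K1), count_squares(K2), count_squares(K1 * K2))  # 12 24 288
```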

    After the two filter stages only a few numbers remain, and they can be checked quickly and exactly with gmpy2.is_square().
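    For the final exact check, the standard library's math.isqrt (Python 3.8+) can stand in for gmpy2.is_square and gmpy2.isqrt when gmpy2 is not at hand; it is equally exact, since Python ints never overflow. The helper name exact_check is made up for this sketch. As a sanity check on the result: for x >= 2 we have (x**3 - 1)**2 < x**6 - 4*x**2 + 4 < (x**3)**2, so only x = 0 and x = 1 can give squares, which matches the output further down.

```python
from math import isqrt

def exact_check(candidates):
    # exact test with Python ints (arbitrary precision, no overflow)
    found = []
    for x in candidates:
        y = x**6 - 4 * x**2 + 4
        r = isqrt(y)       # floor of the square root
        if r * r == y:     # y is a perfect square exactly when isqrt(y)**2 == y
            found.append((x, r))
    return found

print(exact_check(range(2000)))  # [(0, 2), (1, 1)]
```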

    The filtering stage is easy to wrap into a Numba function, as I did below. This function can take the extra Numba parameter parallel = True, which tells Numba to automatically parallelize supported NumPy array operations across all CPU cores.

    In the code I use limit = 1 << 30, the (exclusive) upper bound of all x to be checked, and block = 1 << 26, the number of values processed at a time by the parallel Numba function. If you have enough memory, you may increase block to keep all CPU cores busy more efficiently. A block of size 1 << 26 uses roughly 1 GB of memory.
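    The memory figure follows from simple arithmetic: one uint64 array of 2**26 elements is already 512 MiB, and the filter function holds its input plus intermediate arrays at the same time, which is consistent with the roughly 1 GB estimate. A back-of-the-envelope helper (hypothetical, not part of the code below):

```python
def block_stats(limit, block, itemsize=8):
    # number of blocks processed, and MiB for one array of `block` uint64 values
    n_blocks = -(-limit // block)  # ceiling division
    mib_per_array = block * itemsize / 2**20
    return n_blocks, mib_per_array

print(block_stats(1 << 30, 1 << 26))  # (16, 512.0)
print(block_stats(1 << 30, 1 << 24))  # (64, 128.0)
```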

    Thanks to this filtering idea and the use of a multi-core CPU, my code solves the same task about a hundred times faster than yours.


    import numpy as np, numba
    
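    @numba.njit('u8[:](u8[:], u8, u8, u1[:])', cache = True, parallel = True)
    def do_filt(x, i, K, filt):
        # x holds offsets within the current block; compute on new arrays so x stays intact
        xm = (x + i) % K                 # absolute x, reduced modulo K
        x2 = xm * xm % K                 # x^2 mod K (K is far below 2**32, so products fit in uint64)
        x6 = x2 * x2 % K * x2 % K        # x^6 mod K
        # y = x^6 - 4*x^2 + 4 (mod K); adding 4*K first keeps the unsigned subtraction from wrapping
        y = (x6 + np.uint64(4 * K + 4) - (x2 << np.uint64(2))) % K
        # return the surviving offsets themselves (not their positions in x),
        # so that the second filter stage still receives real offsets
        return x[np.flatnonzero(filt[y])]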
    
    def main():
        import math
        gmpy2 = None
        import gmpy2
        
        Int = lambda x: (int(x) if gmpy2 is None else gmpy2.mpz(x))
        IsSquare = lambda x: gmpy2.is_square(x)
        Sqrt = lambda x: Int(gmpy2.isqrt(x))  # exact integer root; gmpy2.sqrt returns a 53-bit mpfr by default
        
        Ks = [2 * 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19,    23 * 29 * 31 * 37 * 41]
        filts = []
        for i, K in enumerate(Ks):
            a = np.arange(K, dtype = np.uint64)
            a *= a
            a %= K
            filts.append((K, np.zeros((K,), dtype = np.uint8)))
            filts[-1][1][a] = 1
            print(f'filter {i} ratio', round(len(np.flatnonzero(filts[-1][1])) / K, 4))
        
        limit = 1 << 30
        block = 1 << 26
        
        for i in range(0, limit, block):
            print(f'i block {i // block:>3} (2^{math.log2(i + 1):>6.03f})')
            x = np.arange(0, min(block, limit - i), dtype = np.uint64)
            
            for ifilt, (K, filt) in enumerate(filts):
                len_before = len(x)
                x = do_filt(x, i, K, filt)
                print(f'squares filtered by filter {ifilt}:', round(len(x) / len_before, 4))
            
            x_to_check = x
            print(f'remain to check {len(x_to_check)}')
            
            sq_x = []
            for x0 in x_to_check:
                x = Int(i + int(x0))  # int() avoids mixing a Python int with np.uint64
                y = x ** 6 - 4 * x ** 2 + 4
                if not IsSquare(y):
                    continue
                yr = Sqrt(y)
                assert yr * yr == y
                sq_x.append((int(x), int(yr)))
            print('squares found', len(sq_x))
            print(sq_x)
            
            del x
    
    if __name__ == '__main__':
        main()
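    To convince yourself that the filter stage never drops a true solution and hands back real offsets (not positions), you can compare a NumPy-only mirror of do_filt against a brute-force filter that uses Python ints. This is a sketch for testing on small inputs: square_filter, do_filt_np and survivors_bruteforce are hypothetical helpers, and the small moduli 60060 and 7429 stand in for the real Ks to keep it fast.

```python
import numpy as np

def square_filter(K):
    # lookup table: 1 where the residue is a square modulo K
    a = np.arange(K, dtype=np.uint64)
    filt = np.zeros(K, dtype=np.uint8)
    filt[a * a % np.uint64(K)] = 1
    return filt

def do_filt_np(x, i, K, filt):
    # NumPy-only mirror of do_filt: x holds offsets, the result is the surviving offsets
    K = np.uint64(K)
    xm = (x + np.uint64(i)) % K
    x2 = xm * xm % K
    x6 = x2 * x2 % K * x2 % K
    y = (x6 + np.uint64(4) * K + np.uint64(4) - (x2 << np.uint64(2))) % K
    return x[np.flatnonzero(filt[y])]

def survivors_bruteforce(n, i, filters):
    # the same filtering applied directly to the exact value of y, with Python ints
    out = []
    for x0 in range(n):
        x = i + x0
        y = x**6 - 4 * x**2 + 4
        if all(f[y % k] for k, f in filters):
            out.append(x0)
    return out

filters = [(K, square_filter(K)) for K in (4 * 3 * 5 * 7 * 11 * 13, 17 * 19 * 23)]
results = {}
for i in (0, 123_456):
    x = np.arange(20_000, dtype=np.uint64)
    for K, filt in filters:
        x = do_filt_np(x, i, K, filt)
    results[i] = x.tolist()
    assert results[i] == survivors_bruteforce(20_000, i, filters)
    print(i, len(results[i]))
```

    A version that returns np.flatnonzero(...) positions instead of x[...] would, in general, fail this comparison at the second stage.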
    

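    Output (the printed block offsets 2^24, 2^25, ... show that this run used block = 1 << 24 rather than 1 << 26):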

    filter 0 ratio 0.0094
    filter 1 ratio 0.0366
    i block   0 (2^ 0.000)
    squares filtered by filter 0: 0.0211
    squares filtered by filter 1: 0.039
    remain to check 13803
    squares found 2
    [(0, 2), (1, 1)]
    i block   1 (2^24.000)
    squares filtered by filter 0: 0.0211
    squares filtered by filter 1: 0.0392
    remain to check 13880
    squares found 0
    []
    i block   2 (2^25.000)
    squares filtered by filter 0: 0.0211
    squares filtered by filter 1: 0.0391
    remain to check 13835
    squares found 0
    []
    i block   3 (2^25.585)
    squares filtered by filter 0: 0.0211
    squares filtered by filter 1: 0.0393
    remain to check 13907
    squares found 0
    []
    
    ...............................