python math numba

How can I avoid getting the wrong answer when calculating with njit in Python?


I used the njit decorator from the numba library to speed up calculations in Python, and it suddenly gave me a wrong answer. The same code without njit gives the correct answer.

from numba import njit


@njit
def main_with_njit():
    a = 94906267
    print(f'main_with_njit: a**2 = {a ** 2}')  # -> a**2 = 9007199515875288


main_with_njit()


def main_without_njit():
    a = 94906267
    print(f'main_without_njit: a**2 = {a ** 2}')  # -> a**2 = 9007199515875289


main_without_njit()

How can I use njit without getting a wrong answer?


Solution

  • First things first. This error is only possible if the operands are first converted to 64-bit floating point numbers. A 64-bit float has a 53-bit significand, so it cannot represent the odd number 9007199515875289 (which is larger than 2**53) exactly, and the result is instead rounded to 9007199515875288.

    You have not done anything to suggest that you want the operands to be floating point numbers, which suggests that the behaviour is a bug somewhere in the numba / LLVM chain. Indeed, if you look at the generated LLVM code you can see that a ** 2 is transformed into a * a. However, for some reason the result is converted to a double and then back again before being returned.

    As pointed out by @ken, this behaviour is not seen on all versions of Python. Both 3.10 and 3.12 work for me; only 3.11 produced the erroneous result. You can view the generated LLVM code like this:

    import numba
    
    @numba.njit
    def f(a: int) -> int:
        return a ** 2
    
    f(94906267) # NB. code is not generated unless you call the function
    
    args = (numba.int64,)
    code = f.inspect_llvm()[args]
    print(code)  # dump the LLVM IR for the int64 specialisation
    

    Python 3.11 LLVM Code Snippet

      %.185.i = mul nsw i64 %arg.a, %arg.a
      %.224.i = sitofp i64 %.185.i to double
      %.229.i = fptosi double %.224.i to i64
      store i64 %.229.i, i64* %retptr, align 8
    

    You can see that it first does the multiplication, but then for some reason converts the result to a 64-bit floating point number (sitofp i64 %.185.i to double, where sitofp means signed integer to floating point) and then back again (fptosi double %.224.i to i64) before returning it.

    Python 3.12 LLVM Code Snippet

      %.191.i = mul nsw i64 %arg.a, %arg.a
      store i64 %.191.i, i64* %retptr, align 8
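
The rounding described at the start of this answer can be reproduced in plain Python, without numba. The sketch below is only an illustration: it assumes that `int(float(x))` is a fair stand-in for the `sitofp` / `fptosi` pair in the 3.11 snippet, and the variable names are made up for this example.

```python
a = 94906267

# Python ints have arbitrary precision, so this product is exact.
exact = a * a

# Stand-in for the sitofp -> fptosi round trip seen in the 3.11 LLVM code:
# convert the exact integer to a 64-bit float and back to an integer.
round_tripped = int(float(exact))

print(exact)          # 9007199515875289
print(round_tripped)  # 9007199515875288

# The product lies above 2**53, where doubles can no longer
# represent every integer, so the odd result cannot survive the trip.
print(exact > 2**53)  # True
```

The two printed values match the njit and non-njit outputs from the question, which is consistent with the extra float conversion being the cause of the discrepancy.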