Tags: python, performance, numba

Performance loss in numba compiled logic comparison


What could be a reason for the performance degradation in the following Numba-compiled function for logic comparison:

from numba import njit

t = (True, 'and_', False)

#@njit(boolean(boolean, unicode_type, boolean))    
@njit
def f(a,b,c):
    if b == 'and_':
        out = a&c
    elif b == 'or_':
        out = a|c
    return out
x = f(*t)
%timeit f(*t)
#1.78 µs ± 9.52 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

%timeit f.py_func(*t)
#108 ns ± 0.0042 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

To test this at scale as suggested in the answer:

import numpy as np

x = np.random.choice([True,False], 1000000)
y = np.random.choice(["and_","or_"], 1000000)
z = np.random.choice([False, True], 1000000)

#using jit compiled f
def f2(x,y,z):
    L = x.shape[0]
    out = np.empty(L)
    for i in range(L):
        out[i] = f(x[i],y[i],z[i])
    return out

%timeit f2(x,y,z)
#2.79 s ± 86.4 ms per loop

#using pure Python f
def f3(x,y,z):
    L = x.shape[0]
    out = np.empty(L)
    for i in range(L):
        out[i] = f.py_func(x[i],y[i],z[i])
    return out

%timeit f3(x,y,z)
#572 ms ± 24.3 ms per loop

Am I missing something, and is there a way to compile a "fast" version? This is going to be part of a loop executed ~1e6 times.


Solution

  • You are working at too small a granularity; Numba is not designed for that. Almost all of the execution time you see comes from the overhead of wrapping/unwrapping parameters, type checks, Python function wrapping, reference counting, etc. Moreover, the benefit of using Numba is very small here, since Numba barely optimizes unicode string operations.

    One way to check this hypothesis is to benchmark the following trivial function:

    @njit
    def f(a,b,c):
        return a
    x = f(True, 'and_', False)
    %timeit f(True, 'and_', False)
    

    Both the trivial function and the original version take 1.34 µs per call on my machine.

    Additionally, you can disassemble the Numba function to see how many instructions are executed for a single call and understand where the overhead comes from.
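
    For example, Numba dispatchers expose the generated code directly. A minimal sketch using the public inspection API (the exact listing you get depends on your Numba/LLVM version):

    # Print the native assembly generated for each compiled signature of f.
    # inspect_asm() with no argument returns a dict keyed by signature.
    for sig, asm in f.inspect_asm().items():
        print(sig)
        print(asm[:2000])  # the full listing is long; show only the beginning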

    If you want Numba to be useful, you need to put more work inside the compiled function, ideally by working directly on arrays/lists. If that is not possible because of the dynamic nature of the input types, then Numba may not be the right tool here. You could rework your code a bit and use PyPy instead. Writing a native C/C++ module may help a little, but most of the time would still be spent manipulating dynamic objects and unicode strings and doing type introspection, unless you rewrite the whole code.
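
    As a rough sketch of that rework (not from the answer; the name f_batch and the integer operator codes AND_/OR_ are illustrative): move the loop into the compiled function and encode the operator as an integer, so the Python/Numba boundary is crossed once per array rather than once per element, and no unicode comparison happens in the hot loop:

    import numpy as np
    from numba import njit

    AND_, OR_ = 0, 1  # operator encoded as an integer instead of a unicode string

    @njit
    def f_batch(a, op, c):
        out = np.empty(a.shape[0], dtype=np.bool_)
        for i in range(a.shape[0]):
            if op[i] == AND_:
                out[i] = a[i] & c[i]
            else:  # OR_
                out[i] = a[i] | c[i]
        return out

    x = np.random.choice([True, False], 1_000_000)
    y = np.random.choice([AND_, OR_], 1_000_000)
    z = np.random.choice([False, True], 1_000_000)
    out = f_batch(x, y, z)  # one boundary crossing for the whole array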


    UPDATE

    The above overhead is only paid when transitioning from Python types to Numba (and the other way around). You can see that with the following benchmark:

    from numba import jit, njit

    @njit
    def f(a,b,c):
        if b == 'and_':
            out = a&c
        elif b == 'or_':
            out = a|c
        return out
    @jit
    def manyCalls(a, b, c):
        res = True
        for i in range(1_000_000):
            res ^= f(a, b, c ^ res)
        return res
    
    t = (True, 'and_', False)
    x = manyCalls(*t)
    %timeit manyCalls(*t)
    

    Calling manyCalls takes 3.62 ms on my machine. This means each call to f takes 3.6 ns on average (3.62 ms / 1e6 calls, about 16 cycles), and that the wrapping overhead is paid only once, when manyCalls itself is called.