pythonnumbadtw

speedup dtaidistance key function with numba


The DTAIDistance package can be used to find k best matches of the input query. but it cannot be used for multi-dimensional input query. moreover, I want to find the k best matches of many input queries in one run.

I modified the DTAIDistance function so that it can be used to search subsequences of multi-dimensions of multi-queries. I use njit with parallel to speed up the process,i.e.the p_calc function which applies numba-parallel to each of the input query. but I find that the parallel calculation seems not to speed up the calculation compared to just simply looping over the input queries one by one, i.e. the calc function.

import time
from tqdm import tqdm
from numba import njit, prange
import numpy as np
inf = np.inf
argmin=np.argmin
@njit(fastmath=True, nogil=True, error_model="numpy", cache=True, parallel=False)
def p_calc(d, dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    n_series = s1.shape[1]
    ndim = s1.shape[2]
    # s1 = np.ascontiguousarray(s1)#.shape
    # s2 = np.ascontiguousarray(s2)#.shape
    # dtw = np.full((n_series,r + 1, c + 1), np.inf,dtype=s1.dtype)  # cmath.inf
    # d = np.full((n_series), np.inf,dtype=s1.dtype)  # cmath.inf
    for i in range(psi_2b + 1):
        dtw[:, 0, i] = 0
    for i in range(psi_1b + 1):
        dtw[:, i, 0] = 0
    for nn in prange(n_series):
        print('im alive...')
        i0 = 1
        i1 = 0
        sc = 0
        ec = 0
        smaller_found = False
        ec_next = 0
        for i in range(r):
            i0 = i
            i1 = i + 1
            j_start = max(0, i - max(0, r - c) - window + 1)
            j_end = min(c, i + max(0, c - r) + window)
            if sc > j_start:
                j_start = sc
            smaller_found = False
            ec_next = i
        for j in range(j_start, j_end):
            val = 0
            tmp = ((s1[i, nn] - s2[j]) ** 2)
            # tmp = (np.abs(s1[i, nn] - s2[j, 0]))
            for nd in range(ndim):
                val += tmp[nd]
            d[nn] = val
            # d = np.sum(np.abs(s1[i] - s2[j]) )  # multi-d
            if max_step is not None and d[nn] > max_step:
                continue
            # print(i, j + 1 - skip, j - skipp, j + 1 - skipp, j - skip)
            dtw[nn, i1, j + 1] = d[nn] + min(dtw[nn, i0, j],
                                             dtw[nn, i0, j + 1] + penalty,
                                             dtw[nn, i1, j] + penalty)
            # dtw[i + 1, j + 1 - skip] = d + min(dtw[i + 1, j + 1 - skip], dtw[i + 1, j - skip])
            if dtw[nn, i1, j + 1] > max_dist:
                if not smaller_found:
                    sc = j + 1
                if j >= ec:
                    break
            else:
                smaller_found = True
                ec_next = j + 1
        ec = ec_next
    # Decide which d to return
    dtw[nn] = np.sqrt(dtw[nn])
    if psi_1e == 0 and psi_2e == 0:
        d[nn] = dtw[nn, i1, min(c, c + window - 1)]
    else:
        ir = i1
        ic = min(c, c + window - 1)
        if psi_1e != 0:
            vr = dtw[nn, ir:max(0, ir - psi_1e - 1):-1, ic]
            mir = np.argmin(vr)
            vr_mir = vr[mir]
        else:
            mir = ir
            vr_mir = inf
        if psi_2e != 0:
            vc = dtw[nn, ir, ic:max(0, ic - psi_2e - 1):-1]
            mic = np.argmin(vc)
            vc_mic = vc[mic]
        else:
            mic = ic
            vc_mic = inf
        if vr_mir < vc_mic:
            if psi_neg:
                dtw[nn, ir:ir - mir:-1, ic] = -1
            d[nn] = vr_mir
        else:
            if psi_neg:
                dtw[nn, ir, ic:ic - mic:-1] = -1
            d[nn] = vc_mic
    if max_dist and d[nn] ** 2 > max_dist:
        # if max_dist and d[nn] > max_dist:
        d[nn] = inf
return d, dtw


@njit(fastmath=True, nogil=True)  # Set "nopython" mode for best performance, equivalent to @njit
def calc(s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    dtw = np.full((r + 1, c + 1), np.inf)  # cmath.inf
    for i in range(psi_2b + 1):
        dtw[0, i] = 0
    for i in range(psi_1b + 1):
        dtw[i, 0] = 0
    i0 = 1
    i1 = 0
    sc = 0
    ec = 0
    smaller_found = False
    ec_next = 0
    for i in range(r):
        i0 = i
        i1 = i + 1
        j_start = max(0, i - max(0, r - c) - window + 1)
        j_end = min(c, i + max(0, c - r) + window)
        if sc > j_start:
            j_start = sc
        smaller_found = False
        ec_next = i
        for j in range(j_start, j_end):
            # d = (s1[i] - s2[j]) ** 2# 1-d
            d = np.sum((s1[i] - s2[j]) ** 2)  # multi-d
            # d = np.sum(np.abs(s1[i] - s2[j]) )  # multi-d
            if max_step is not None and d > max_step:
                continue
            dtw[i1, j + 1] = d + min(dtw[i0, j],
                                     dtw[i0, j + 1] + penalty,
                                     dtw[i1, j] + penalty)
            if dtw[i1, j + 1] > max_dist:
                if not smaller_found:
                    sc = j + 1
                if j >= ec:
                    break
            else:
                smaller_found = True
                ec_next = j + 1
        ec = ec_next
    # Decide which d to return
    dtw = np.sqrt(dtw)
    if psi_1e == 0 and psi_2e == 0:
        d = dtw[i1, min(c, c + window - 1)]
    else:
        ir = i1
        ic = min(c, c + window - 1)
        if psi_1e != 0:
            vr = dtw[ir:max(0, ir - psi_1e - 1):-1, ic]
            mir = argmin(vr)
            vr_mir = vr[mir]
        else:
            mir = ir
            vr_mir = inf
        if psi_2e != 0:
            vc = dtw[ir, ic:max(0, ic - psi_2e - 1):-1]
            mic = argmin(vc)
            vc_mic = vc[mic]
        else:
            mic = ic
            vc_mic = inf
        if vr_mir < vc_mic:
            if psi_neg:
                dtw[ir:ir - mir:-1, ic] = -1
            d = vr_mir
        else:
            if psi_neg:
                dtw[ir, ic:ic - mic:-1] = -1
            d = vc_mic
    if max_dist and d * d > max_dist:
        d = inf
    return d, dtw


mydtype = np.float32
series1 = np.random.random((16, 30, 2)).astype(mydtype)
series2 = np.random.random((100000,  2)).astype(mydtype)
n_series = series1.shape[1]
r = series1.shape[0]
c = series2.shape[0]
dtw = np.full((n_series, r + 1, c + 1), np.inf, dtype=mydtype)  # cmath.inf
d = np.full((n_series), np.inf, dtype=mydtype)  # cmath.inf
time1 = time.time()
d, dtw1 = p_calc(d, dtw, series1, series2, series1.shape[0], series2.shape[0], 0, 0,
               series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.time() - time1)

time1 = time.time()
for ii in tqdm(range(series1.shape[1])):
    d, dtw1 = calc( series1[:, ii, :], series2, series1.shape[0], series2.shape[0], 0, 0,
                   series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.time() - time1)#   this one is faster

How can I speed up the calc function or p_calc function so that I can calculate the dynamic time warping paths of multi-dimensional multi-queries?

thanks for the answer,and then i modified the code for simplification. i delete the np.sum part and use loop,i can get another speedup. any suggestions for further speedups ?

import time
from numba import njit, prange
import numpy as np
inf = np.inf
argmin=np.argmin
@njit(fastmath=True, nogil=True, error_model="numpy", cache=False, parallel=True)
def p_calc(d, dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    n_series = s1.shape[1]
    ndim = s1.shape[2]
    for nn in prange(n_series):
        for i in range(r):
            j_start = 0
            j_end = c
            for j in range(j_start, j_end):
                val = 0
                # tmp = ((s1[i, nn] - s2[j]) ** 2)
                # tmp = (np.abs(s1[i, nn] - s2[j, 0]))
                for nd in range(ndim):
                    tmp = ((s1[i, nn,nd] - s2[j,nd]) ** 2)
                    val += tmp
                d[nn] = val
    return d, dtw


@njit(fastmath=True, nogil=True)  # Set "nopython" mode for best performance, equivalent to @njit
def calc(dtw,s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    ndim = s1.shape[-1]
    for i in range(r):
        j_start = 0
        j_end = c
        for j in range(j_start, j_end):
            d = 0
            for kk in range(ndim):
                d += (s1[i, kk] - s2[j, kk]) ** 2
    return d, dtw


mydtype = np.float32
series1 = np.random.random((16, 300, 2)).astype(mydtype)
series2 = np.random.random((1000000,  2)).astype(mydtype)
n_series = series1.shape[1]
r = series1.shape[0]
c = series2.shape[0]
dtw = np.full((n_series, r + 1, c + 1), np.inf, dtype=mydtype)  # cmath.inf
d = np.full((n_series), np.inf, dtype=mydtype)  # cmath.inf
time1 = time.time()
# assert 1==2
# dtw[:,series2.shape[0]]
d1, dtw1 = p_calc(d, dtw, series1, series2, series1.shape[0], series2.shape[0], 0, 0, series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.time() - time1)
# assert 1==2
time1 = time.time()
dtw = np.full(( r + 1, c + 1), np.inf, dtype=mydtype)  # cmath.inf
for ii in (range(series1.shape[1])):
    d2, dtw2 = calc( dtw,series1[:, ii, :], series2, series1.shape[0], series2.shape[0], 0, 0,
                   series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.time() - time1)#   this one is faster
np.allclose(dtw1[-1],dtw2)
np.allclose(d1[-1],d2)

EDIT: i found the following code's performance is very different if use pass or break. i don't understand why?

@njit(fastmath=True, nogil=True)
def kbest_matches(matching,k=4000):
    ki = 0
    while  ki < k:
        best_idx =np.argmin(matching)# np.argmin(np.arange(10000000))#
        if best_idx == 0 :
            # pass
            break
        ki += 1
    return 0

ss= np.random.random((1575822,))
time1 = time.time()
pp = kbest_matches(ss)
print(time.time() - time1)

Solution

  • I assume the code of both implementations are correct and as been carefully checked (otherwise the benchmark would be pointless).

    The issue likely comes from the compilation time of the function. Indeed, the first call is significantly slower than next calls, even with cache=True. This is especially important for the parallel implementation as compiling parallel Numba code is often slower (since it is more complex). The best solution to avoid this is to compile Numba functions ahead of time by providing types to Numba.

    Besides this, benchmarking a computation only once is usually considered as a bad practice. Good benchmarks perform multiple iterations and remove the first ones (or consider them separately). Indeed, several other problems can appear when a code is executed for the first time: CPU caches (and the TLB) are cold, the CPU frequency can change during the execution and is likely smaller when the program is just started, page faults may need to be needed, etc.

    In practice, I cannot reproduce the issue. Actually, p_calc is 3.3 times faster on my 6-core machine. When the benchmark is done in a loop of 5 iterations, the measured time of the parallel implementation is much smaller: about 13 times (which is actually suspicious for a parallel implementation using 6 threads on a 6-core machine).