python, numpy, tree, monte-carlo-tree-search

What is the most efficient way to access nodes of a tree stored in a NumPy array


Imagine we have a tree of values stored in a NumPy array. For example -

In [1]: import numpy as np

In [2]: tree = np.array([[0, 6], [0, 4], [1, 3], [2, 9], [3, 1], [2, 7]]);

In [3]: tree.shape
Out[3]: (6, 2)

Each node in the tree is a row in the array. The first row tree[0] is the root node [0, 6]. The first column tree[:,0] contains the row number of the node's parent and the second column tree[:,1] contains the node's value attribute.

What is the most efficient way to access the value attributes of a given node up to the root via its ancestors? For example, for the sixth node [2, 7], this would be [7, 3, 4, 6].
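
To make the layout concrete, the parent chain of the sixth node can be followed by hand (purely illustrative, using the tree defined above):

print(tree[5])   # [2 7] -> value 7, parent is row 2
print(tree[2])   # [1 3] -> value 3, parent is row 1
print(tree[1])   # [0 4] -> value 4, parent is row 0 (the root, value 6)
print(tree[0])   # [0 6] -> the root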

One method is to iteratively read the array from the starting node upwards, using the first column as the index of the next ancestor, for example -

In [20]: i = 5
    ...: values = []
    ...: while i > 0:
    ...:     values.append(tree[i, 1])
    ...:     i = tree[i, 0]
    ...: values.append(tree[0, 1])
    ...: print(values)
[7, 3, 4, 6]

but I found this to be slow for large complex trees. Is there a faster way?

Background to my question - I am trying to implement Monte Carlo tree search (MCTS).


Solution

  • For an iterative operation like this, NumPy does not provide any (efficient) vectorized function. One way to speed this up is to use Numba (a JIT compiler) and to return a NumPy array (since Numba operates on them more efficiently). Here is an example:

    import numba as nb
    import numpy as np
    
    # Eagerly compile for the integer dtypes the tree is likely to use
    @nb.njit(['(int16[:,:], int_)', '(int32[:,:], int_)', '(int64[:,:], int_)'])
    def compute(tree, i):
        # Pre-allocate the worst case: a path can visit at most every node
        values = np.empty(max(tree.shape[0], 1), dtype=tree.dtype)
        cur = 0
        while i > 0:
            assert cur < values.size
            values[cur] = tree[i, 1]  # store the value, then jump to the parent
            i = tree[i, 0]
            cur += 1
        assert cur < values.size
        values[cur] = tree[0, 1]  # finally append the root's value
        return values[:cur+1]  # Consider using copy() if cur << tree.shape[0]
    
    print(compute(tree, 5))
    

    It takes 0.76 us on my machine, as opposed to 1.36 us for the initial code. However, ~0.54 us is spent calling the JIT function and checking the input parameters, and 0.1~0.2 us is spent allocating the output array. Thus, basically 90% of the time of the Numba function is constant overhead. It should be much faster for large trees. If you have many small trees to process, you can call it from another Numba function so as to avoid the overhead of calling a JIT function from the slow CPython interpreter, as sketched below. When called from a JIT function, the above function takes only 0.063 us on the example input. Thus, the Numba function can be up to 22 times faster in this case.
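
    A minimal sketch of that pattern follows. The wrapper name compute_many is made up for illustration, and it assumes the dtype of the start indices matches one of the int_ signatures above (e.g. int64 on 64-bit Linux):

    @nb.njit
    def compute_many(tree, starts):
        # Calling another JIT function from JIT code skips the CPython call overhead
        total = 0
        for i in starts:
            path = compute(tree, i)
            total += path.sum()  # placeholder: do whatever your MCTS actually needs here
        return total

    starts = np.array([5, 4, 3], dtype=np.int64)
    print(compute_many(tree, starts))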

    Note that it is better to use a small datatype for the tree, since random accesses are expensive in large arrays. The smaller the array is in memory, the more likely it is to fit in the CPU caches, and the faster the computation. For trees with fewer than 65536 items, it is safe to use a uint16 datatype (while the default one is int32 on Windows and int64 on Linux, that is, respectively 2 and 4 times bigger).
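
    For instance, the tree from the question can simply be cast before being passed to compute. This is just a sketch: int16 is already covered by the signatures above (safe for fewer than 32768 nodes), whereas uint16 would require adding a '(uint16[:,:], int_)' signature to the decorator.

    small_tree = tree.astype(np.int16)  # fine here: all parent indices and values are small
    print(compute(small_tree, 5))       # same result as before, with less memory traffic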