python scipy jit jax kdtree

Jax jitting of kd-tree code taking an intractably long amount of time


I've written myself into a corner with the following situation:

The problem comes with jit-compiling the kd-tree code. I've written it in the 'standard way', using objects for the tree nodes with left and right fields for the children. At the leaf nodes, these fields are None to signify the absence of children.

The code runs and is functionally correct, but jit-compiling it takes a long time: 72 seconds for a tree of 64 coordinates, 131 seconds for 343 coordinates, ... and my intended dataset has over 14 million points. I think that internally Jax is tracing every possible path through the tree, which is why it takes so long. Once compiled, the result is blazingly quick: 0.0075s for a kd-tree 10-point retrieval vs 0.4s for a brute-force search over all of the points (for 343 points). These are the speeds I'm hoping to obtain for use in the optimiser (without jitting it will be too slow). However, that doesn't seem possible if the compilation time keeps growing like this.

I thought that the problem might lie in the structure of the tree, with lots of different objects to be stored, so I have also implemented a kd-tree search algorithm where the tree is represented by a set of Jax-numpy arrays (e.g. coord, value, left and right, where each index corresponds to a point in the tree) and iteration rather than recursion is used to do the tree search (this was a challenge but it works!). However, converting this to work with jit (replacing if-statements with jax.lax.cond) is going to be complicated, and before I start I was wondering whether it's going to be worth it. Surely I'll have the same problem: Jax will trace all branches of the tree until the 'null terminators' (-1 values in the left and right arrays) are reached, and it will still take a very long time to compile. I've been investigating structures like jax.lax.while_loop, in case they might help.

(I've also written a hybrid of the two approaches, with an array-based tree and a recursion-based algorithm. In this case the tracing goes into an infinite loop, I think because of the fact that the null-terminator is -1 rather than None. But the arrays should be known statically (they don't change after construction, and belong to an object which is marked as a static input), so maybe the solution lies in this and I'm doing something wrong.)

I was wondering if I'm doing anything which is obviously wrong (or if my understanding is wrong), and if there is anything I can do to speed it up? Is it just to be expected that the compile time would be so high when there are so many code paths to trace? I don't suppose I could even build the jitted function only once and then save it?

I'm concerned that the only solution may be to rewrite the optimiser code so that it doesn't use Jax (e.g. if I hard-code the derivatives, and rewrite some of the code so that it operates on arrays directly instead of being vectorised across the inputs).

The code is available here: https://github.com/FluffyCodeMonster/jax_kd_tree

All three varieties described are given: the node-based tree with recursion, the array-based tree with iteration, and the array-based tree with recursion. The first works, but is very slow to jit compile as the number of points in the tree increases; the second also works, but is not yet written in a jit-able way. The last is written to be jitted, but can't jit compile as it gets into an infinite recursion.

I really need to get this working urgently so that I can obtain the optimisation results.


Solution

  • All Python-level control flow, including if statements, for and while loops, and recursion, is traced in full and flattened into a linear sequence of operations that is then sent to the compiler. If you attempt a tree traversal via Python-level control flow, you'll end up with very large programs that take a very long time to compile. This issue is discussed broadly at JAX sharp bits: control flow.

    If you want to traverse a KD tree under JIT without the long compilation, you'll have to use an iterative approach with XLA control-flow operators such as jax.lax.fori_loop and jax.lax.while_loop.

    Alternatively, you might think about using jax.pure_callback to run neighbor queries with scipy on the host. There is some discussion of this at Exploring pure_callback. It's not super efficient (each call incurs some host synchronization and data movement overhead), but it can be a pretty effective solution for things like this, particularly if you're running on CPU.
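The iterative approach suggested above could look roughly like the sketch below. It is not taken from the linked repository: the names build_tree, nearest and STACK_SIZE are hypothetical, the data is random, and for brevity it returns a single nearest neighbor rather than the 10-point retrieval from the question. The tree is built once on the host with NumPy into flat arrays (with -1 as the "no child" marker), and the search keeps an explicit stack inside jax.lax.while_loop, so the traced program has a fixed size regardless of how many points the tree holds.

```python
import numpy as np
import jax
import jax.numpy as jnp

STACK_SIZE = 64  # assumption: well above the depth of any balanced tree we build


def build_tree(points):
    """Host-side (NumPy) build. Node i is point i; -1 means 'no child'."""
    n, dim = points.shape
    left = np.full(n, -1, dtype=np.int32)
    right = np.full(n, -1, dtype=np.int32)
    axis = np.zeros(n, dtype=np.int32)

    def build(idx, depth):
        if idx.size == 0:
            return -1
        ax = depth % dim
        idx = idx[np.argsort(points[idx, ax])]  # median split along ax
        mid = idx.size // 2
        node = int(idx[mid])
        axis[node] = ax
        left[node] = build(idx[:mid], depth + 1)
        right[node] = build(idx[mid + 1:], depth + 1)
        return node

    root = build(np.arange(n), 0)
    return root, left, right, axis


@jax.jit
def nearest(query, points, left, right, axis, root):
    """1-nearest-neighbor search; returns (index, squared distance)."""
    stack = jnp.zeros(STACK_SIZE, dtype=jnp.int32).at[0].set(root)
    bound = jnp.zeros(STACK_SIZE, dtype=points.dtype)  # lower bound per entry
    init = (stack, bound, jnp.int32(1), jnp.int32(-1),
            jnp.array(jnp.inf, dtype=points.dtype))

    def cond(state):
        return state[2] > 0  # loop while the stack is non-empty

    def body(state):
        stack, bound, top, best_idx, best_d2 = state
        top = top - 1  # pop
        node, node_bound = stack[top], bound[top]
        # Skip this subtree if it cannot contain anything closer.
        visit = node_bound < best_d2

        p = points[node]
        d2 = jnp.sum((p - query) ** 2)
        better = visit & (d2 < best_d2)
        best_idx = jnp.where(better, node, best_idx)
        best_d2 = jnp.where(better, d2, best_d2)

        ax = axis[node]
        diff = query[ax] - p[ax]
        go_left = diff < 0
        near = jnp.where(go_left, left[node], right[node])
        far = jnp.where(go_left, right[node], left[node])

        # Push the far child first so the near child is popped first.
        push_far = visit & (far >= 0)
        far_bound = jnp.maximum(node_bound, diff * diff)
        stack = stack.at[top].set(jnp.where(push_far, far, stack[top]))
        bound = bound.at[top].set(jnp.where(push_far, far_bound, bound[top]))
        top = top + push_far.astype(jnp.int32)

        push_near = visit & (near >= 0)
        stack = stack.at[top].set(jnp.where(push_near, near, stack[top]))
        bound = bound.at[top].set(jnp.where(push_near, node_bound, bound[top]))
        top = top + push_near.astype(jnp.int32)
        return stack, bound, top, best_idx, best_d2

    _, _, _, best_idx, best_d2 = jax.lax.while_loop(cond, body, init)
    return best_idx, best_d2


# Illustrative usage on random data (343 points, as in the question).
rng = np.random.default_rng(0)
pts = rng.random((343, 3)).astype(np.float32)
root, left, right, axis = build_tree(pts)
query = rng.random(3).astype(np.float32)
idx, d2 = nearest(jnp.asarray(query), jnp.asarray(pts), jnp.asarray(left),
                  jnp.asarray(right), jnp.asarray(axis), root)
```

Note that the tree arrays are passed as ordinary array arguments rather than marked static, so they are not baked into the compiled program. The branches are expressed with jnp.where instead of Python if statements, which is what keeps the loop body a single traced block. Batching over many queries should be possible with jax.vmap, since it supports jax.lax.while_loop.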
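The pure_callback route mentioned above might be sketched as follows. This is only an illustration under assumptions: the names K, _host_query and knn_indices are hypothetical, the points are random, and scipy.spatial.cKDTree stands in for whichever host-side tree you prefer. The tree is built once outside of JAX, and the jitted function only declares the shape and dtype of the result so that the compiler can treat the host call as an opaque operation.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.spatial import cKDTree

K = 10  # number of neighbors to retrieve (as in the question)

rng = np.random.default_rng(0)
points = rng.random((343, 3)).astype(np.float32)  # illustrative data
tree = cKDTree(points)  # built once, on the host


def _host_query(q):
    # Runs as ordinary NumPy/SciPy code on the host, outside of tracing.
    _, idx = tree.query(np.asarray(q), k=K)
    return idx.astype(np.int32)


@jax.jit
def knn_indices(q):
    # Tell JAX what the callback returns: K indices per query point.
    out = jax.ShapeDtypeStruct(q.shape[:-1] + (K,), jnp.int32)
    return jax.pure_callback(_host_query, out, q)


q = jnp.asarray([0.5, 0.5, 0.5], dtype=jnp.float32)
idx = knn_indices(q)
# The gather and the distance computation happen back in JAX.
neighbors = jnp.asarray(points)[idx]
dists = jnp.linalg.norm(neighbors - q, axis=-1)
```

Only the index lookup happens on the host here; everything that uses the indices afterwards stays in JAX. Compilation time no longer depends on the size of the tree, at the cost of one host round trip per call.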