
Execution of conditional branches causing errors in Jax (kd-tree implementation)


I'm writing a kd-tree in Jax, using custom-written Node objects for the tree elements. Each Node is very simple: a single data field (for holding numeric values), plus left and right fields that reference other Nodes. A leaf Node is one whose left and right fields are both None.

The traversal code checks the values of left and right before descending: it only tries to traverse the left or right branch of a node's subtree if that branch actually exists. Checks like if (current_node.left is not None) worked fine for this (or does it have to be jax.numpy.logical_not(current_node.left is None) in Jax? I've tried both). However, since converting the if statements to jax.lax.cond(...) I've been getting the error AttributeError: 'NoneType' object has no attribute 'left'.

I think the situation is similar to the following minimal working example:

import jax
import jax.numpy as jnp

def my_func(val):
    return 2*val

@jax.jit
def test_fn(a):
    return jax.lax.cond(a is not None,
                lambda: my_func(a),
                lambda: 0)

print(test_fn(2))       # Prints 4
# in test_fn(), a has type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
print(test_fn(None))    # TypeError: unsupported operand type(s) for *: 'int' and 'NoneType'
# in test_fn(), a has type <class 'NoneType'>

In this code, if the Jax cond statement were a regular if statement, my_func() wouldn't even be called when a is None, and no error would be raised. As I understand it, Jax traces the function, meaning that both branches are executed, so my_func() gets called with None (when a is None), causing the error. I believe something similar is happening in my tree code: branches are being executed even when .left and/or .right are None, whereas a traditional if statement would never run them.

Is my understanding correct, and what could I do about this issue? Strangely, the minimum working example code also has the problem when the @jax.jit decorator is omitted, suggesting that both branches are still being traced.
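One way to check this is to give each branch a Python side effect and see which branch functions jax.lax.cond actually calls. In the sketch below (the branch names and the calls list are illustrative), the predicate is a concrete True and there is no jit, yet with JAX's default settings both functions end up being called, because lax.cond traces each branch into a jaxpr before selecting one.

```python
import jax
import jax.numpy as jnp

calls = []  # records which branch functions JAX invokes while tracing


def true_branch(x):
    calls.append("true")  # Python side effect: runs when the branch is traced
    return x * 2


def false_branch(x):
    calls.append("false")
    return x * 0


# Concrete predicate, no @jax.jit: both branches are still traced.
result = jax.lax.cond(True, true_branch, false_branch, jnp.float32(3.0))
print(result)
print(sorted(calls))
```

The returned value comes from the true branch, but the false branch was traced too. That is why a branch that is invalid for the given input (such as my_func(None)) fails even when it would never be selected.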


As a related point, is the tree structure 'baked into' the Jax/XLA code? I have noticed that larger trees take longer to jit-compile, which makes me concerned that this might not be a valid approach for the very large number of points I need to represent (about 14,000,000). I would use the regular SciPy kd-tree implementation, but unfortunately it isn't compatible with Jax, which the rest of my code requires. I might ask this as a separate question for clarity.
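One commonly used way to keep the compiled program independent of the tree size is to store the tree in flat arrays rather than as linked Python objects, using -1 as a sentinel for a missing child instead of None, and to walk it with jax.lax.while_loop. This is a swapped-in technique, not something from the question. The layout (split_dim, split_val, left, right, root at index 0) and the descend_to_leaf function are hypothetical, and the sketch only descends from the root to the leaf whose cell contains a query point. A full nearest-neighbour search would also need backtracking, for example with an explicit stack array.

```python
import jax
import jax.numpy as jnp


# Hypothetical flat kd-tree layout: node i splits on coordinate split_dim[i]
# at split_val[i]; its children are left[i] and right[i], with -1 meaning
# "no child" (an array-friendly replacement for None). Node 0 is the root.
def descend_to_leaf(split_dim, split_val, left, right, query):
    """Return the index of the leaf whose cell contains `query`."""

    def has_child(node):
        return (left[node] >= 0) | (right[node] >= 0)

    def step(node):
        go_left = query[split_dim[node]] < split_val[node]
        preferred = jnp.where(go_left, left[node], right[node])
        other = jnp.where(go_left, right[node], left[node])
        # If the preferred child is missing, fall back to the one that exists.
        return jnp.where(preferred >= 0, preferred, other)

    return jax.lax.while_loop(has_child, step, jnp.int32(0))


descend = jax.jit(descend_to_leaf)

# Tiny 2-D example: the root splits on x at 0.5; nodes 1 and 2 are leaves.
split_dim = jnp.array([0, 0, 0], dtype=jnp.int32)
split_val = jnp.array([0.5, 0.0, 0.0], dtype=jnp.float32)
left = jnp.array([1, -1, -1], dtype=jnp.int32)
right = jnp.array([2, -1, -1], dtype=jnp.int32)

leaf_a = descend(split_dim, split_val, left, right, jnp.array([0.2, 0.9]))
leaf_b = descend(split_dim, split_val, left, right, jnp.array([0.8, 0.1]))
print(leaf_a, leaf_b)
```

Because the tree arrives as ordinary array arguments, the size of the traced loop body does not depend on how many nodes there are. Passing arrays of a different length triggers a retrace, but not a larger program.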


Solution

  • If you are using jax.lax.cond, both branches are traced regardless of the value of the predicate, so the input must be valid for both branches. When a is None, the first branch is invalid because 2 * None raises an error.

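    In this case, the condition a is not None is known statically: None is a Python object rather than a traced array, so the check evaluates to an ordinary Python bool while the function is being traced. Rather than using lax.cond, you can therefore use a regular if statement, and only the branch that is actually taken gets traced: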

    import jax

    @jax.jit
    def test_fn(a):
        # a is not None is a Python bool at trace time, so only the taken branch is traced
        return my_func(a) if a is not None else 0
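Applied to the tree, the same idea means keeping structural checks such as node.left is None as plain Python if statements, and reserving JAX operations for conditions on traced values. The sketch below assumes a minimal dataclass stand-in for the question's Node class and a hypothetical count_less function. It also uses jax.make_jaxpr to count the operations in each traced program, which shows that walking a tree of Python objects unrolls every node into the program. In that sense the tree structure is baked into the compiled code.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class Node:
    """Hypothetical minimal stand-in for the question's Node class."""
    data: float
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def count_less(node, x):
    """Hypothetical helper: count nodes in the subtree whose data is < x."""
    # Structural check: node is a Python object, so a plain if is resolved while tracing.
    if node is None:
        return jnp.int32(0)
    # Value check: x is traced, so the comparison must stay a JAX operation.
    here = (x > node.data).astype(jnp.int32)
    return here + count_less(node.left, x) + count_less(node.right, x)


small = Node(1.0, Node(0.5), Node(2.0))
big = Node(4.0, small, Node(6.0, Node(5.0), Node(7.0)))

n_small = jax.jit(lambda x: count_less(small, x))(1.5)
n_big = jax.jit(lambda x: count_less(big, x))(5.5)
print(n_small, n_big)

# Number of primitive operations in each traced program: grows with tree size.
eqns_small = len(jax.make_jaxpr(lambda x: count_less(small, x))(1.5).jaxpr.eqns)
eqns_big = len(jax.make_jaxpr(lambda x: count_less(big, x))(5.5).jaxpr.eqns)
print(eqns_small, eqns_big)
```

No None ever reaches a JAX operation here, so the AttributeError from the question cannot occur. With millions of nodes, though, this unrolling would make the traced program enormous, which matches the growing compile times described in the question.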