I'm writing a kd-tree in Jax, and using custom-written Node objects for the tree elements. Each Node is very simple, with a single `data` field (for holding numeric values) and `left` and `right` fields which are references to other Nodes. A leaf Node is identified as one for which the `left` and `right` fields are `None`.
The code performs conditional checks on the values of `left` and `right` as part of the tree traversal process; e.g. it will only try to traverse down the left or right branch of a node's subtree if it actually exists. Doing checks like `if current_node.left is not None:` (or does it have to be `jax.numpy.logical_not(current_node.left is None)` in Jax? I've tried both) was fine for this, but since converting the `if` statements to `jax.lax.cond(...)` I've been getting the error `AttributeError: 'NoneType' object has no attribute 'left'`.
I think the situation might be like the one in the following minimal working example:
```python
import jax
import jax.numpy as jnp

def my_func(val):
    return 2*val

@jax.jit
def test_fn(a):
    return jax.lax.cond(a is not None,
                        lambda: my_func(a),
                        lambda: 0)

print(test_fn(2))  # Prints 4
# in test_fn(), a has type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
print(test_fn(None))  # TypeError: unsupported operand type(s) for *: 'int' and 'NoneType'
# in test_fn(), a has type <class 'NoneType'>
```
In this code, if the Jax `cond` statement were a regular `if` statement, `my_func()` wouldn't even be called when `a` is `None`, and no error would be raised. To the best of my understanding, Jax traces the function, meaning that all branches are executed, and this leads to `my_func()` being called with `None` (when `a` is `None`), causing the error. I believe a similar situation is arising in my tree code, where conditional branches are being executed even though `.left` and/or `.right` are `None`, whereas a traditional `if` statement would never execute those branches.
Is my understanding correct, and what could I do about this issue? Strangely, the minimal working example also has the problem when the `@jax.jit` decorator is omitted, suggesting that both branches are still being traced.
As a related point, is the tree structure 'baked into' the Jax/XLA code? I have noticed that larger trees take longer to jit-compile, which makes me concerned that this might not be a valid approach for the very large number of points I need to represent (about 14,000,000). I would use the regular SciPy kd-tree implementation, but unfortunately it isn't compatible with Jax, and the rest of my code requires Jax. I might ask this as a separate question for clarity.
If you are using `jax.lax.cond`, the input must have a valid type for both branches, because `lax.cond` traces both branch functions regardless of the predicate's value. When `a` is `None`, the first branch is invalid because `2 * None` results in an error.
In this case, the condition `a is not None` is known statically (it is an ordinary Python `bool` evaluated at trace time), so rather than using `lax.cond` you can use a regular `if` statement:
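```python
@jax.jit
def test_fn(a):
    # `a is not None` is evaluated in Python while tracing, so only the
    # branch that is actually taken gets traced.
    return my_func(a) if a is not None else 0
```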
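The same reasoning should apply to the tree traversal in the question, provided the Node objects are ordinary Python objects rather than traced values. The sketch below uses a simplified, hypothetical `Node` class and hypothetical `tree_sum`/`build` helpers (not the asker's actual kd-tree code). It uses plain `if` checks on `left`/`right` and also measures the traced program with `jax.make_jaxpr`. Because the Python recursion runs at trace time, every node is unrolled into the jaxpr. That bears on the "baked in" question: program size grows with the number of nodes, which is consistent with the slower compiles observed for larger trees and suggests this approach would likely be impractical for ~14,000,000 points.

```python
import jax

class Node:
    """Hypothetical minimal Node: a numeric value plus optional children."""
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

def tree_sum(node, x):
    # node.left / node.right are Python objects, so these checks run at
    # trace time; no lax.cond is needed and None children are never visited.
    total = node.data * x
    if node.left is not None:
        total = total + tree_sum(node.left, x)
    if node.right is not None:
        total = total + tree_sum(node.right, x)
    return total

def build(depth):
    # Hypothetical helper: complete binary tree with 2**depth - 1 nodes.
    if depth == 0:
        return None
    return Node(1.0, build(depth - 1), build(depth - 1))

small, large = build(3), build(6)  # 7 and 63 nodes

f_small = jax.jit(lambda x: tree_sum(small, x))
value = f_small(2.0)  # 7 nodes, each contributing 1.0 * 2.0

# Count the operations recorded in each traced program.
n_small = len(jax.make_jaxpr(lambda x: tree_sum(small, x))(2.0).jaxpr.eqns)
n_large = len(jax.make_jaxpr(lambda x: tree_sum(large, x))(2.0).jaxpr.eqns)
print(value, n_small, n_large)  # the larger tree yields a larger jaxpr
```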