I am trying to understand how pytrees work, and I registered my own class as a pytree. I noticed that if the `aux_data` in the pytree is a `jax.numpy.ndarray`, the auxiliary data is subsequently traced and returned as a `Traced<ShapedArray(...)>...`. However, if the `aux_data` is a `numpy.ndarray` (i.e. not a JAX array), then it is not traced, and a concrete array is returned from the `jit`-transformed function.

Now, I am aware of the tracing that happens during the `jax.jit()` transformation, but I do not understand why, at the level of pytrees, this results in the behaviour described above.
Here is an example that reproduces this behaviour (it multiplies both the `aux_data` and the tree leaves by two, which may be a problem in itself after JIT transformation...?). For comparison, I also used the custom pytree implementations of established libraries (`equinox` and `simple_pytree`), and they all give the same result, so I am quite sure that this is not a bug but a feature that I am trying to understand.
```python
import jax
from jax.tree_util import tree_structure, tree_leaves
import numpy as np

def get_pytree_impl(base):
    if base == "equinox":
        import equinox as eqx
        Module = eqx.Module
        static_field = eqx.static_field
    elif base == "simple_pytree":
        from simple_pytree import Pytree, static_field
        Module = Pytree
    elif base == "dataclasses":
        from dataclasses import dataclass, field
        @dataclass
        class Module():
            pass
        static_field = field

    class PytreeImpl(Module):
        x: jax.numpy.ndarray
        y: jax.numpy.ndarray = static_field()

        def __init__(self, x, y):
            self.x = x
            self.y = y

    if base == 'dataclasses':
        from jax.tree_util import register_pytree_node

        def flatten(ptree):
            return ((ptree.x,), ptree.y)

        def unflatten(aux_data, children):
            return PytreeImpl(*children, aux_data)

        register_pytree_node(PytreeImpl, flatten, unflatten)

    return PytreeImpl

def times_two(ptree):
    return type(ptree)(ptree.x*2, ptree.y*2)

times_two_jitted = jax.jit(times_two)

bases = ['dataclasses', 'equinox', 'simple_pytree']
for base in bases:
    print("======== " + base + " ========")
    for lib_name, array_lib in zip(['jnp', 'np'], [jax.numpy, np]):
        print("==== " + lib_name)
        PytreeImpl = get_pytree_impl(base)
        x = jax.numpy.array([1,2])
        y = array_lib.array([3,4])
        input_tree = PytreeImpl(x, y)
        for tag, pytree in zip(["input", "no_jit", "jit"],[input_tree, times_two(input_tree), times_two_jitted(input_tree)]):
            print(f' {tag}:')
            print(f'\t Structure: {tree_structure(pytree)}')
            print(f'\t Leaves: {tree_leaves(pytree)}')
```
This produces the following output, where `dataclasses` is my naive custom implementation of a pytree:
```text
======== dataclasses ========
==== jnp
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[3 4]], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[6 8]], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=1/0)>], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
==== np
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[3 4]], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[6 8]], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[6 8]], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
======== equinox ========
==== jnp
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (Array([3, 4], dtype=int32),)], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (Array([6, 8], dtype=int32),)], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=1/0)>,)], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
==== np
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (array([3, 4]),)], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (array([6, 8]),)], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (array([6, 8]),)], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
======== simple_pytree ========
==== jnp
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': Array([3, 4], dtype=int32), '_pytree__initialized': True})], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': Array([6, 8], dtype=int32), '_pytree__initialized': True})], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=1/0)>, '_pytree__initialized': True})], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
==== np
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': array([3, 4]), '_pytree__initialized': True})], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': array([6, 8]), '_pytree__initialized': True})], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': array([6, 8]), '_pytree__initialized': True})], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
```
I ran this example using Python 3.12.1 with equinox 0.11.4, jax 0.4.28, jaxlib 0.4.28 and simple-pytree 0.1.5.
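A minimal sketch of where the `Traced<...>` value in the `jit` rows most likely comes from, assuming current JAX tracing behaviour: under `jax.jit`, `jax.numpy` operations are staged into the traced computation even when their operands are concrete arrays. Inside the jitted `times_two`, the tree is rebuilt with the original concrete `y`, so `ptree.y*2` is a tracer when `y` is a `jax.Array`, but an ordinary eagerly computed result when `y` is an `np.ndarray`. That value is then stored in the output's `aux_data`, which `jit` never converts back into a concrete array. The names `y_jnp`, `y_np` and `seen` below are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-ins for the `y` value that the question stores in aux_data.
y_jnp = jnp.array([3, 4])
y_np = np.array([3, 4])

seen = {}  # side channel used only to inspect intermediate values


def f(x):
    # Under jax.jit, jax.numpy operations are staged into the traced computation,
    # even when their operands are concrete arrays, so this result is a tracer.
    seen["jnp"] = y_jnp * 2
    # NumPy knows nothing about JAX tracing and computes the product eagerly.
    seen["np"] = y_np * 2
    return x * 2


jax.jit(f)(jnp.array([1, 2]))
print(type(seen["jnp"]).__name__)
print(type(seen["np"]).__name__)
```

Keeping a tracer after the traced function has returned is exactly the kind of leak that putting arrays in `aux_data` makes easy.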
From the JAX docs:

> When defining unflattening functions, in general `children` should contain all the dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while `aux_data` should contain all the static elements that will be rolled into the treedef structure.
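A small sketch of how the children/`aux_data` split shows up in practice, using a built-in `dict` node instead of a custom class: the dict keys act as the static metadata stored in the treedef, and the values are the children. Changing a value leaves the treedef unchanged; changing a key produces a different treedef.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_structure, tree_unflatten

# For a dict node, the keys play the role of aux_data (static, part of the
# treedef) and the values are the children (dynamic leaves).
a = {"x": jnp.array([1, 2]), "y": jnp.array([3, 4])}
b = {"x": jnp.array([10, 20]), "y": jnp.array([30, 40])}  # same keys, new values
c = {"x": jnp.array([1, 2]), "z": jnp.array([3, 4])}      # different key

leaves, treedef = tree_flatten(a)
print(len(leaves))                             # 2: only the values become leaves
print(tree_structure(a) == tree_structure(b))  # True: values are not in the treedef
print(tree_structure(a) == tree_structure(c))  # False: keys are in the treedef

# Rebuild a structure from the static treedef plus new dynamic leaves.
rebuilt = tree_unflatten(treedef, [jnp.zeros(2), jnp.ones(2)])
print(sorted(rebuilt))                         # ['x', 'y']
```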
`aux_data` in a pytree flattening must contain static elements, and static elements must be hashable and immutable. Neither `np.ndarray` nor `jax.Array` satisfies this, so they should not be included in `aux_data`. If you do include such values in `aux_data`, you'll get unsupported, poorly defined behavior.
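A quick check of the hashability requirement; `is_hashable` is a hypothetical helper, not a JAX API. `jax.Array` is immutable but still unhashable, while `np.ndarray` is both unhashable and mutable.

```python
import jax.numpy as jnp
import numpy as np


def is_hashable(value):
    """Return True if hash(value) succeeds (hypothetical helper)."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


np_arr = np.array([3, 4])
jax_arr = jnp.array([3, 4])

print(is_hashable(np_arr))   # False
print(is_hashable(jax_arr))  # False
print(is_hashable((3, 4)))   # True: a tuple of Python ints is a valid static value

# np.ndarray is also mutable, so a value captured in a treedef could change later.
np_arr[0] = 30
print(np_arr)
```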
With that background, the answer to your question of why you're seeing these results is that you are defining your pytrees incorrectly. Arrays such as `y` are dynamic data and belong in `children`; if you define `aux_data` to only contain static (i.e. hashable and immutable) attributes, you will no longer see this behavior.
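A minimal sketch of the fix, assuming the arrays are meant to be dynamic data: `y` moves into `children`, and only a hashable, immutable attribute (a hypothetical `name` string) stays in `aux_data`. The class `FixedPytree` and its `name` field are illustrative and not part of the original code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node, tree_leaves


class FixedPytree:
    """Hypothetical rework of the question's PytreeImpl: `y` is now a child."""

    def __init__(self, x, y, name="demo"):
        self.x = x
        self.y = y
        self.name = name  # hashable, immutable metadata: safe for aux_data


def flatten(tree):
    # Dynamic values (arrays) are children; only static metadata is aux_data.
    return (tree.x, tree.y), tree.name


def unflatten(aux_data, children):
    x, y = children
    return FixedPytree(x, y, name=aux_data)


register_pytree_node(FixedPytree, flatten, unflatten)


@jax.jit
def times_two(tree):
    return FixedPytree(tree.x * 2, tree.y * 2, name=tree.name)


results = []
for y in (jnp.array([3, 4]), np.array([3, 4])):
    out = times_two(FixedPytree(jnp.array([1, 2]), y))
    results.append(out)
    print(tree_leaves(out))
```

With `y` as a child, a NumPy input is converted to a JAX array like any other leaf, so the `jnp` and `np` cases behave identically and no tracer ends up in the treedef. If a value really is static configuration, storing it in a hashable form (for example a tuple of Python ints) rather than as an array keeps it valid as `aux_data`.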