I am writing a custom class that is basically a wrapper around a list, with a custom __setitem__ method. I would like this class to work inside jax.jit code, and while doing that I ran into the following problem: during jitting, the list field is converted to a tuple. However, this only happens when using register_pytree_node_class. When I use register_dataclass, the list stays a list.
I have simplified the example to highlight only this problem.
import jax
from jax.tree_util import register_dataclass
from jax.tree_util import register_pytree_node_class
from functools import partial
from dataclasses import dataclass
from typing import List

@partial(register_dataclass,
         data_fields=['data'],
         meta_fields=['shift'])
@dataclass
class DecoratorFlatten:
    data: List[int]
    shift: int = 5

@register_pytree_node_class
@dataclass
class CustomFlatten:
    data: List[int]
    shift: int = 5

    def tree_flatten(self):
        children = self.data
        aux_data = self.shift
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)
        obj.data = children
        setattr(obj, 'shift', aux_data)
        return obj
Now let's call a simple function like this one on instances of these two classes:
@jax.jit
def get_value(a):
    return a.data

df = DecoratorFlatten([0, 1, 2])
cf = CustomFlatten([0, 1, 3])
get_value(df), get_value(cf)
In the first case we get a list as output, but in the second a tuple. I thought this might be because of my implementation of the tree_flatten method, however:
cf.tree_flatten()
returns ([0, 1, 3], 5), as desired.
In tree_unflatten, children is a tuple, and you are assigning this directly to obj.data. If you want it to be a list, you should use obj.data = list(children).
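As a minimal sketch of that fix, the class from the question could be adjusted as below. The name CustomFlattenFixed is made up for this illustration, and building the object through the dataclass constructor instead of object.__new__ is my own simplification, not something the question requires.

```python
from dataclasses import dataclass
from typing import List

import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass
class CustomFlattenFixed:  # hypothetical name, used only for this sketch
    data: List[int]
    shift: int = 5

    def tree_flatten(self):
        # The list elements become the children; shift is static aux data.
        return (self.data, self.shift)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # JAX hands the children back as a tuple, so convert explicitly
        # to keep `data` a list after a round trip through jit.
        return cls(data=list(children), shift=aux_data)


@jax.jit
def get_value(a):
    return a.data


out = get_value(CustomFlattenFixed([0, 1, 3]))
print(type(out))
```

With this change get_value should return a list (of JAX arrays) for the custom class as well. An alternative, which as far as I understand is closer to what register_dataclass does, is to return the whole list as a single child, i.e. children = (self.data,), and read it back as children[0] in tree_unflatten.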