python jax

jax register_pytree_node_class and register_dataclass return inconsistent data types: tuple and list, respectively


I am writing a custom class that is basically a wrapper around a list with a custom __setitem__ method. I would like this class to work inside jax.jit-compiled code, and while doing so I found the following problem: during jitting, the List field is converted to a tuple. However, this happens only when using register_pytree_node_class. With register_dataclass, the list stays a list.

I simplified the example to highlight only this problem.

import jax
from jax.tree_util import register_dataclass
from jax.tree_util import register_pytree_node_class
from functools import partial
from dataclasses import dataclass
from typing import List

@partial(register_dataclass,
         data_fields=['data'],
         meta_fields=['shift'])
@dataclass
class DecoratorFlatten:
    data: List[int]
    shift: int = 5

@register_pytree_node_class
@dataclass
class CustomFlatten:
    data: List[int]
    shift: int = 5

    def tree_flatten(self):
        children = self.data
        aux_data = self.shift
        return (children, aux_data)
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)
        obj.data = children
        setattr(obj, 'shift', aux_data)
        return obj

Now let's call this simple function on instances of these two classes:

@jax.jit
def get_value(a):
    return a.data
df = DecoratorFlatten([0,1,2])
cf = CustomFlatten([0,1,3])
get_value(df), get_value(cf)

In the first case we get a list as output, but in the second a tuple. I thought this might be caused by my implementation of the tree_flatten method, however:

cf.tree_flatten()

returns ([0, 1, 3], 5), as desired.


Solution

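  • In tree_unflatten, children arrives as a tuple, and you are assigning it directly to obj.data. If you want it to be a list, use obj.data = list(children). The difference from register_dataclass comes from what counts as a child: your tree_flatten returns the list itself as children, so each element becomes a separate child and the list container is not recorded in the tree structure. register_dataclass treats the whole data field as a single child, so the list survives as a nested pytree and is rebuilt as a list.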
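
As a minimal sketch of the two ways to get a list back, the snippet below defines two variants of the class from the question. The names CustomFlattenA and CustomFlattenB are hypothetical and used only for illustration. Variant A converts children with list(...) as suggested above. Variant B wraps the list in a one-element tuple so that the list is a single child, which is assumed here to mirror how register_dataclass handles a data field.

```python
import jax
from dataclasses import dataclass
from typing import List
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass
class CustomFlattenA:  # hypothetical name: converts children back to a list
    data: List[int]
    shift: int = 5

    def tree_flatten(self):
        # Each element of self.data becomes a separate child.
        return self.data, self.shift

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # children arrives as a tuple, so rebuild the list explicitly.
        return cls(data=list(children), shift=aux_data)


@register_pytree_node_class
@dataclass
class CustomFlattenB:  # hypothetical name: keeps the list as one child
    data: List[int]
    shift: int = 5

    def tree_flatten(self):
        # The one-element tuple makes the whole list a single child,
        # so JAX records the list container in the tree structure.
        return (self.data,), self.shift

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (data,) = children  # tuple of length 1 holding the rebuilt list
        return cls(data=data, shift=aux_data)


@jax.jit
def get_value(a):
    return a.data


out_a = get_value(CustomFlattenA([0, 1, 3]))
out_b = get_value(CustomFlattenB([0, 1, 3]))
print(type(out_a), type(out_b))
```

Variant A is the smaller change. Variant B may be preferable if data could later hold something other than a flat list, since the container type is then preserved by JAX rather than hard-coded in tree_unflatten.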