Tags: python, jax, flax

Flax nnx / jax: tree.map for layers of incongruent size


I am trying to figure out how to use nnx.split_rngs. Can somebody give a version of the code below that uses nnx.split_rngs with jax.tree.map to produce an arbitrary number of Linear layers with different out_features?

import jax
from flax import nnx
from functools import partial

if __name__ == '__main__':

    session_sizes = {
        'a':2,
        'b':3,
        'c':4,
        'd':5,
        'e':6,
    }
    dz = 2

    rngs = nnx.Rngs(0)
    
    my_linear = partial(
        nnx.Linear,
        use_bias = False,
        in_features = dz,
        rngs=rngs )
    
    def my_linear_wrapper(a):
        return my_linear( out_features=a )

    q_s = jax.tree.map(my_linear_wrapper, session_sizes)

    for k in session_sizes.keys():
        print(q_s[k].kernel)

So in this case, we would need a tree of layers that will take our 2 in_features into spaces of 2, ..., 6 out_features.

The function my_linear_wrapper is a workaround for the solution we originally had in mind: mapping in much the same fashion as above, but using (something like) the @nnx.split_rngs function decorator instead.

Is there a way to use nnx.split_rngs on my_linear in order to map over the rng argument to nnx.Linear?


Solution

  • split_rngs is mostly useful when you are going to pass the Rngs through a transform like nnx.vmap. Here you want to produce variable-sized Modules, so the current solution is the way to go. Because of how partial works, you can simplify it to:

    import functools

    import jax
    from flax import nnx

    session_sizes = {'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6}
    din = 2
    rngs = nnx.Rngs(0)

    # partial binds din as the positional in_features argument, so each
    # call only needs the remaining positional argument: out_features.
    my_linear = functools.partial(
      nnx.Linear, din, use_bias=False, rngs=rngs
    )

    q_s = jax.tree.map(my_linear, session_sizes)

    for k in session_sizes.keys():
      print(q_s[k].kernel)