python · jax · flax

Is there a way to trace grads through the self.put_variable method in Flax?


I would like to trace the grads through self.put_variable. Is there any way to make that possible? Or is there another way to update the param supplied to the module such that the update is traced?

import jax 
from jax import numpy as jnp 
from jax import grad,random,jit,vmap 
import flax 
from flax import linen as nn 


class network(nn.Module):
    input_size : int 
    output_size : int 
    @nn.compact
    def __call__(self,x):
        W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
        b = self.param('b',nn.initializers.normal(),(self.output_size,))

      
        self.put_variable("params","b",(x@W+b).reshape(5,))  # store the updated value of b in the "params" collection
    
        return jnp.sum(x+b)


if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x,key_param,key = random.split(key,3)
    x = random.normal(key_x,(1,5))

    module = network(5,5)
    param = module.init(key_param,x)
    print(param)
    #x,param = module.apply(param,x,mutable=["params"])
    #print(param)
    print(grad(module.apply,has_aux=True)(param,x,mutable=["params"]))

My output grads are:

FrozenDict({
    params: {
        W: DeviceArray([[0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.]], dtype=float32),
        b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
    },
})

This suggests that the grads are not traced through the self.put_variable method, as the grads of W are all zero, while b clearly relies upon W.
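
For reference, the behaviour can be reproduced without Flax by reducing __call__ to a plain function (a hypothetical reduction, not code from the module above); the scalar that gets returned only involves the original b, which is consistent with the zero grads for W:

import jax
from jax import numpy as jnp

# Hypothetical reduction of __call__ above: put_variable stores x @ W + b,
# but the returned scalar is built from the local b, i.e. the original parameter.
def call(W, b, x):
    new_b = (x @ W + b).reshape(5,)   # the value that put_variable would store
    loss = jnp.sum(x + b)             # uses only the original b
    return loss, new_b

W = jnp.zeros((5, 5))
b = jnp.zeros((5,))
x = jnp.ones((1, 5))
print(jax.grad(call, has_aux=True)(W, b, x)[0])  # grads w.r.t. W: all zeros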


Solution

  • As @jakevdp noted, the test above is incorrect because the b used in the return value is still tied to the previous b.
    According to https://github.com/google/flax/discussions/2215, self.put_variable is traced.

    Testing if that is actually the case using the code below:

    import jax 
    from jax import numpy as jnp 
    from jax import grad,random,jit,vmap 
    import flax 
    from flax import linen as nn 
    
    class network(nn.Module):
        input_size : int 
        output_size : int 
        @nn.compact
        def __call__(self,x):
            W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
            b = self.param('b',nn.initializers.normal(),(self.output_size,))
    
            b = x@W+b  # rebind b to the updated value; otherwise the return below would still use the previous b
            self.put_variable("params","b",(b).reshape(5,))  
         
            return jnp.sum(x+b)
    
    def test_update(param,x):
        _, param = module.apply(param,x,mutable=["params"])
        return jnp.sum(param["params"]["b"]+x),param 
    
    if __name__ == "__main__":
        key = random.PRNGKey(0)
        key_x,key_param,key = random.split(key,3)
        x = random.normal(key_x,(1,5))
    
        module = network(5,5)
        param = module.init(key_param,x)
        print(param)
    
        print(grad(test_update,has_aux=True)(param,x))
    

    output:

    FrozenDict({
        params: {
            W: DeviceArray([[ 0.01678762,  0.00234134,  0.00906202,  0.00027337,
                           0.00599653],
                         [-0.00729604, -0.00417799,  0.00172333, -0.00566238,
                           0.0097266 ],
                         [ 0.00378883, -0.00901531,  0.01898266, -0.01733185,
                          -0.00616944],
                         [-0.00806503,  0.00409351,  0.0179838 , -0.00238476,
                           0.00252594],
                         [ 0.00398197,  0.00030245, -0.00640218, -0.00145424,
                           0.00956188]], dtype=float32),
            b: DeviceArray([-0.00905032, -0.00574646,  0.01621638, -0.01165553,
                         -0.0285466 ], dtype=float32),
        },
    })
    (FrozenDict({
        params: {
            W: DeviceArray([[-1.1489547 , -1.1489547 , -1.1489547 , -1.1489547 ,
                          -1.1489547 ],
                         [-2.0069852 , -2.0069852 , -2.0069852 , -2.0069852 ,
                          -2.0069852 ],
                         [ 0.98777294,  0.98777294,  0.98777294,  0.98777294,
                           0.98777294],
                         [ 0.9311977 ,  0.9311977 ,  0.9311977 ,  0.9311977 ,
                           0.9311977 ],
                         [-0.2883922 , -0.2883922 , -0.2883922 , -0.2883922 ,
                          -0.2883922 ]], dtype=float32),
            b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
        },
    }), FrozenDict({
        params: {
            W: DeviceArray([[ 0.01678762,  0.00234134,  0.00906202,  0.00027337,
                           0.00599653],
                         [-0.00729604, -0.00417799,  0.00172333, -0.00566238,
                           0.0097266 ],
                         [ 0.00378883, -0.00901531,  0.01898266, -0.01733185,
                          -0.00616944],
                         [-0.00806503,  0.00409351,  0.0179838 , -0.00238476,
                           0.00252594],
                         [ 0.00398197,  0.00030245, -0.00640218, -0.00145424,
                           0.00956188]], dtype=float32),
            b: DeviceArray([-0.01861148, -0.00523183,  0.03968921, -0.01952654,
                         -0.06145691], dtype=float32),
        },
    }))
    
    

    The first FrozenDict contains the original parameters.
    The second FrozenDict contains the grads, which are clearly traced through self.put_variable.
    The last FrozenDict contains the returned parameters, where we can see that b has been correctly updated.
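
    As a follow-up usage sketch (not part of the original test; it reuses module, param, x and test_update from the script above), the mutated params returned by apply(..., mutable=["params"]) can simply be fed back into the next call, which addresses the question's ask for another way to update the param supplied to the module:

    grads, new_param = grad(test_update, has_aux=True)(param, x)

    # new_param["params"]["b"] already holds x @ W + b; the FrozenDict can be used
    # like any other parameter pytree, e.g. for the next forward pass ...
    y, newer_param = module.apply(new_param, x, mutable=["params"])

    # ... or for a plain SGD-style step on top of the traced grads
    # (the 0.1 learning rate is arbitrary).
    updated_param = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, new_param, grads)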