I would like to trace gradients through self.put_variable. Is there any way to make that possible, or another way to update the params supplied to the module such that the update is traced?
import jax
from jax import numpy as jnp
from jax import grad, random, jit, vmap
import flax
from flax import linen as nn


class network(nn.Module):
    input_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        W = self.param('W', nn.initializers.normal(), (self.input_size, self.output_size))
        b = self.param('b', nn.initializers.normal(), (self.output_size,))
        # write the new b into the "params" collection;
        # note that the return below still uses the old b
        self.put_variable("params", "b", (x @ W + b).reshape(5,))
        return jnp.sum(x + b)


if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x, key_param, key = random.split(key, 3)
    x = random.normal(key_x, (1, 5))
    module = network(5, 5)
    param = module.init(key_param, x)
    print(param)

    # x, param = module.apply(param, x, mutable=["params"])
    # print(param)

    print(grad(module.apply, has_aux=True)(param, x, mutable=["params"]))
My output grads are:
FrozenDict({
params: {
W: DeviceArray([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},
This seems to show that gradients are not traced through self.put_variable: the grads for W are all zero, even though the new b clearly depends on W.
However, as @jakevdp noted, that test is flawed: the return value still uses the local b obtained from self.param, not the value written with self.put_variable, so a zero gradient for W is expected regardless of whether put_variable is traced.
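To see that the zero W gradient comes from using the stale local b rather than from put_variable itself, here is a minimal pure-JAX sketch (my own illustration; the names stale and fresh are hypothetical and not part of the original code):

from jax import numpy as jnp
from jax import grad

def stale(W, b, x):
    b_new = x @ W + b          # computed, but never used below
    return jnp.sum(x + b)      # still uses the old b, so the W gradient is zero

def fresh(W, b, x):
    b_new = x @ W + b
    return jnp.sum(x + b_new)  # uses the new b, so the W gradient is nonzero

W0 = jnp.ones((5, 5))
b0 = jnp.zeros((5,))
x0 = jnp.ones((1, 5))
print(grad(stale)(W0, b0, x0))  # all zeros
print(grad(fresh)(W0, b0, x0))  # all ones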
According to https://github.com/google/flax/discussions/2215, self.put_variable is traced.
The code below tests whether that is actually the case:
import jax
from jax import numpy as jnp
from jax import grad, random, jit, vmap
import flax
from flax import linen as nn


class network(nn.Module):
    input_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        W = self.param('W', nn.initializers.normal(), (self.input_size, self.output_size))
        b = self.param('b', nn.initializers.normal(), (self.output_size,))
        b = x @ W + b  # rebind b, otherwise the return below is still tied to the previous b
        self.put_variable("params", "b", b.reshape(5,))
        return jnp.sum(x + b)


def test_update(param, x):
    # apply the module with the "params" collection mutable so that
    # self.put_variable actually writes into the returned params
    _, param = module.apply(param, x, mutable=["params"])
    # build the loss from the updated b read back out of the mutated params
    return jnp.sum(param["params"]["b"] + x), param


if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x, key_param, key = random.split(key, 3)
    x = random.normal(key_x, (1, 5))
    module = network(5, 5)
    param = module.init(key_param, x)
    print(param)

    print(grad(test_update, has_aux=True)(param, x))
output:
FrozenDict({
params: {
W: DeviceArray([[ 0.01678762, 0.00234134, 0.00906202, 0.00027337,
0.00599653],
[-0.00729604, -0.00417799, 0.00172333, -0.00566238,
0.0097266 ],
[ 0.00378883, -0.00901531, 0.01898266, -0.01733185,
-0.00616944],
[-0.00806503, 0.00409351, 0.0179838 , -0.00238476,
0.00252594],
[ 0.00398197, 0.00030245, -0.00640218, -0.00145424,
0.00956188]], dtype=float32),
b: DeviceArray([-0.00905032, -0.00574646, 0.01621638, -0.01165553,
-0.0285466 ], dtype=float32),
},
})
(FrozenDict({
params: {
W: DeviceArray([[-1.1489547 , -1.1489547 , -1.1489547 , -1.1489547 ,
-1.1489547 ],
[-2.0069852 , -2.0069852 , -2.0069852 , -2.0069852 ,
-2.0069852 ],
[ 0.98777294, 0.98777294, 0.98777294, 0.98777294,
0.98777294],
[ 0.9311977 , 0.9311977 , 0.9311977 , 0.9311977 ,
0.9311977 ],
[-0.2883922 , -0.2883922 , -0.2883922 , -0.2883922 ,
-0.2883922 ]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},
}), FrozenDict({
params: {
W: DeviceArray([[ 0.01678762, 0.00234134, 0.00906202, 0.00027337,
0.00599653],
[-0.00729604, -0.00417799, 0.00172333, -0.00566238,
0.0097266 ],
[ 0.00378883, -0.00901531, 0.01898266, -0.01733185,
-0.00616944],
[-0.00806503, 0.00409351, 0.0179838 , -0.00238476,
0.00252594],
[ 0.00398197, 0.00030245, -0.00640218, -0.00145424,
0.00956188]], dtype=float32),
b: DeviceArray([-0.01861148, -0.00523183, 0.03968921, -0.01952654,
-0.06145691], dtype=float32),
},
}))
The first FrozenDict is the original parameters printed after module.init.
The second FrozenDict holds the grads, which are clearly traced through self.put_variable: each row of the W gradient is constant (the analytic gradient of this loss with respect to W[i, j] is simply x[i]), and the gradient with respect to the original b is all ones, exactly as expected.
The last FrozenDict is the mutated parameters returned as the aux output, where we can see that b has been correctly updated.
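For completeness, here is one way this pattern could be used in an update step. This is only a sketch of mine, not part of the original code: sgd_step and the 0.1 learning rate are illustrative assumptions, and it reuses module, test_update, param and x from the script above.

@jit
def sgd_step(param, x):
    # grads are taken w.r.t. the params passed in; the aux output is the
    # params pytree returned by apply(..., mutable=["params"]), i.e. the one
    # containing the b written by self.put_variable
    grads, mutated_param = grad(test_update, has_aux=True)(param, x)
    # plain SGD on the incoming params (an illustrative choice; one could
    # also start from mutated_param, which already holds the updated b)
    new_param = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, param, grads)
    return new_param, mutated_param

new_param, mutated_param = sgd_step(param, x)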