I'm trying to work out how to do transfer learning with flax.nnx. Below is my attempt to freeze the kernel of an nnx.Linear instance and optimize only the bias. I suspect I'm not correctly setting up the 'wrt' argument to my optimizer.
from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
from matplotlib import pyplot as plt

def f(x, m=2.234, b=-1.123):
    return m*x + b

def compute_loss(model, inputs, obs):
    prediction = model(inputs)
    error = obs - prediction
    loss = jnp.mean(error ** 2)
    mae = jnp.mean(jnp.abs(error))
    return loss, mae

if __name__ == '__main__':
    shape = (2, 55, 1)
    epochs = 123
    rngs = nnx.Rngs(123)
    model = nnx.Linear(1, 1, rngs=rngs)
    model.kernel.value = jnp.array([[2.0]])  # load pretrained kernel
    skey = rngs.params()
    xx = random.uniform(skey, shape, minval=-10, maxval=10)
    obs1, obs2 = f(xx)
    x1, x2 = xx

    loss_grad = nnx.value_and_grad(compute_loss, has_aux=True)

    @nnx.scan(
        in_axes=(nnx.Carry, None, None),
        out_axes=(nnx.Carry, 0),
        length=epochs
    )
    def optimizer_scan(optimizer, x, obs):
        (loss, mae), grads = loss_grad(optimizer.model, x, obs)
        optimizer.update(grads)
        return optimizer, (loss, mae)

    transfer_params = nnx.All(nnx.PathContains("bias"))
    optimizer_transfer = nnx.Optimizer(model, optax.adam(learning_rate=1e-3), wrt=transfer_params)
    optimizer, (losses, maes) = optimizer_scan(optimizer_transfer, x1, obs1)

    print(' AFTER TRAINING')
    print('training loss:', losses[-1])
    y1, y2 = optimizer.model(xx)
    error = obs2 - y2
    loss = jnp.mean(error * error)
    print('test loss:', loss)
    print('m approximation:', optimizer.model.kernel.value)
    print('b approximation:', optimizer.model.bias.value)
And this results in the following error:
ValueError: Mismatch custom node data: ('bias', 'kernel') != ('bias',); value: State({
  'bias': VariableState(
    type=Param,
    value=Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
  )
}).
The missing link for me was nnx.DiffState. By default, nnx.value_and_grad differentiates with respect to every nnx.Param in the model, so grads contains entries for both kernel and bias, while the optimizer, having been built with wrt=transfer_params, only carries state for the bias; optimizer.update then fails on the structure mismatch shown above. For clarification on DiffState, see the documentation for nnx.grad() on the nnx "transforms" page:
https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html
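You can see the mismatch directly by splitting the model state with the same filter the optimizer uses. A minimal sketch, added here for illustration (it reuses model, transfer_params, compute_loss, x1, and obs1 from the code above):

# The optimizer's wrt filter restricts its state to the bias...
graphdef, bias_state, rest = nnx.split(model, transfer_params, ...)
print(bias_state)  # State containing only 'bias'
# ...but by default nnx.value_and_grad differentiates every Param:
(loss, mae), grads = nnx.value_and_grad(compute_loss, has_aux=True)(model, x1, obs1)
print(grads)  # State containing 'bias' AND 'kernel' -> structures disagree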
Anyway, effectively the only changes needed to make the code work as intended are:
1. Move the declaration of transfer_params to before the value_and_grad call.
2. Create an nnx.DiffState object: diff_state = nnx.DiffState(0, transfer_params).
3. Pass diff_state as the argnums keyword to nnx.value_and_grad.
And that does it!
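In isolation, the fix is just these three lines (same names as in the full example below):

transfer_params = nnx.All(nnx.PathContains("bias"))  # select only the bias
diff_state = nnx.DiffState(0, transfer_params)       # argument 0 = the model
loss_grad = nnx.value_and_grad(compute_loss, argnums=diff_state, has_aux=True)

With this, grads contains only the bias, which matches the state the optimizer tracks under wrt=transfer_params, so optimizer.update succeeds.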
Another helpful example of how to use nnx.DiffState with parameter filtering can be found here:
https://github.com/google/flax/issues/4167
And lastly, here is the complete fixed example:
from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
from matplotlib import pyplot as plt

def f(x, m=2.234, b=-1.123):
    return m*x + b

def compute_loss(model, inputs, obs):
    prediction = model(inputs)
    error = obs - prediction
    loss = jnp.mean(error ** 2)
    mae = jnp.mean(jnp.abs(error))
    return loss, mae

if __name__ == '__main__':
    shape = (2, 55, 1)
    epochs = 123
    rngs = nnx.Rngs(123)
    model = nnx.Linear(1, 1, rngs=rngs)
    model.kernel.value = jnp.array([[2.0]])  # load pretrained kernel
    skey = rngs.params()
    xx = random.uniform(skey, shape, minval=-10, maxval=10)
    obs1, obs2 = f(xx)
    x1, x2 = xx

    # Change 1: declare the filter before the value_and_grad call
    transfer_params = nnx.All(nnx.PathContains("bias"))
    # Change 2: wrap it in a DiffState for argument 0 (the model)
    diff_state = nnx.DiffState(0, transfer_params)
    # Change 3: pass the DiffState as argnums, so gradients are taken
    # with respect to the bias only
    loss_grad = nnx.value_and_grad(compute_loss, argnums=diff_state, has_aux=True)

    @nnx.scan(
        in_axes=(nnx.Carry, None, None),
        out_axes=(nnx.Carry, 0),
        length=epochs
    )
    def optimizer_scan(optimizer, x, obs):
        (loss, mae), grads = loss_grad(optimizer.model, x, obs)
        optimizer.update(grads)
        return optimizer, (loss, mae)

    optimizer_transfer = nnx.Optimizer(model, optax.adam(learning_rate=1e-3), wrt=transfer_params)
    optimizer, (losses, maes) = optimizer_scan(optimizer_transfer, x1, obs1)

    print(' AFTER TRAINING')
    print('training loss:', losses[-1])
    y1, y2 = optimizer.model(xx)
    error = obs2 - y2
    loss = jnp.mean(error * error)
    print('test loss:', loss)
    print('m approximation:', optimizer.model.kernel.value)
    print('b approximation:', optimizer.model.bias.value)
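As a quick sanity check (my addition, not part of the original run), you can confirm the kernel really was frozen: since gradients, and therefore optimizer updates, were restricted to the bias, the kernel should be bit-for-bit unchanged from the "pretrained" value, while the bias is the only parameter that moved.

# Kernel untouched: still exactly the value we loaded before training.
assert jnp.allclose(optimizer.model.kernel.value, jnp.array([[2.0]]))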