python · jax · flax

Freezing filtered parameter collections with Flax.nnx


I'm trying to work out how to do transfer learning with flax.nnx. Below is my attempt to freeze the kernel of my nnx.Linear instance and optimize the bias. I think maybe I'm not correctly setting up the 'wrt' argument to my optimizer.

from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
from matplotlib import pyplot as plt

def f(x,m=2.234,b=-1.123):
    return m*x+b

def compute_loss(model, inputs, obs):
    prediction = model(inputs)
    error = obs - prediction
    loss = jnp.mean(error ** 2)
    mae = jnp.mean(jnp.abs(error ) )
    return loss, mae

if __name__ == '__main__':
    shape = (2,55,1)
    epochs = 123

    rngs = nnx.Rngs(123)
    model = nnx.Linear( 1, 1, rngs=rngs )

    model.kernel.value = jnp.array([[2.0]]) #load pretrained kernel  

    skey = rngs.params()
    xx = random.uniform( skey, shape, minval=-10, maxval=10 ) 
    obs1,obs2 = f(xx)
    x1,x2 = xx
    
    loss_grad = nnx.value_and_grad(compute_loss, has_aux = True)
    @nnx.scan(
        in_axes=(nnx.Carry,None,None,),
        out_axes=(nnx.Carry,0),
        length=epochs
    )
    def optimizer_scan( optimizer, x, obs ):
        (loss,mae), grads = loss_grad( optimizer.model, x, obs )        
        optimizer.update( grads )
        return optimizer, (loss,mae)

    transfer_params = nnx.All(nnx.PathContains("bias"))
    optimizer_transfer = nnx.Optimizer(model, optax.adam(learning_rate=1e-3), wrt = transfer_params)

    optimizer, (losses,maes) = optimizer_scan( optimizer_transfer, x1, obs1 )

    print( ' AFTER TRAINING' )
    print( 'training loss:', losses[-1] )

    y1,y2 = optimizer.model(xx)
    error = obs2-y2
    loss = jnp.mean( error*error )
    print( 'test loss:',loss )
    print( 'm approximation:', optimizer.model.kernel.value )
    print( 'b approximation:', optimizer.model.bias.value )

And this results in the following error:

ValueError: Mismatch custom node data: ('bias', 'kernel') != ('bias',); value: State({
  'bias': VariableState(
    type=Param,
    value=Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
  )
}).

Solution

  • The missing link for me was nnx.DiffState. The error happens because nnx.value_and_grad, by default, differentiates with respect to every nnx.Param in the model (both kernel and bias), while the optimizer, created with wrt=transfer_params, only tracks the bias, so the two structures don't line up when optimizer.update(grads) is called. nnx.DiffState lets you restrict differentiation to the same filter; a short standalone sketch of this follows below. For more detail on DiffState, see the documentation for nnx.grad() on the nnx "transforms" page:

    https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html

    Anyway, effectively the only changes that need to be made for the code to work as intended are:

    1. Move the declaration of transfer_params above the nnx.value_and_grad call,

    2. Create an nnx.DiffState object: diff_state = nnx.DiffState(0, transfer_params) (the 0 is the positional index of the model argument in compute_loss),

    3. Pass diff_state as the argnums keyword argument to nnx.value_and_grad.

    And that does it!
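
    Here is that sketch, showing the effect of DiffState in isolation. This snippet is my own illustration, not part of the original question; loss_fn and the dummy data are just placeholders. With argnums=nnx.DiffState(0, transfer_params), the gradient State contains only the filtered parameters, which is exactly the structure a wrt-filtered optimizer expects.

    from flax import nnx
    from jax import numpy as jnp

    model = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
    transfer_params = nnx.All(nnx.PathContains("bias"))

    def loss_fn(m, x, y):
        return jnp.mean((m(x) - y) ** 2)

    x = jnp.ones((4, 1))
    y = jnp.zeros((4, 1))

    # Default behaviour: gradients are taken w.r.t. every nnx.Param of argument 0,
    # so the returned State contains both 'kernel' and 'bias'.
    grads_all = nnx.grad(loss_fn)(model, x, y)

    # With DiffState, differentiation is restricted to the filtered parameters,
    # so the returned State contains only 'bias'.
    grads_bias = nnx.grad(loss_fn, argnums=nnx.DiffState(0, transfer_params))(model, x, y)

    print(grads_all)   # State with 'bias' and 'kernel'
    print(grads_bias)  # State with 'bias' only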

    Another helpful example of how to use nnx.DiffState with parameter filtering can be found here:

    https://github.com/google/flax/issues/4167

    And lastly here is the complete fixed example:

    from jax import numpy as jnp
    from jax import random
    from flax import nnx
    import optax
    from matplotlib import pyplot as plt
    
    def f(x,m=2.234,b=-1.123):
        return m*x+b
    
    def compute_loss(model, inputs, obs):
        prediction = model(inputs)
        error = obs - prediction
        loss = jnp.mean(error ** 2)
        mae = jnp.mean(jnp.abs(error ) )
        return loss, mae
    
    if __name__ == '__main__':
        shape = (2,55,1)
        epochs = 123
    
        rngs = nnx.Rngs(123)
        model = nnx.Linear( 1, 1, rngs=rngs )
    
        model.kernel.value = jnp.array([[2.0]]) #load pretrained kernel
    
        skey = rngs.params()
        xx = random.uniform( skey, shape, minval=-10, maxval=10 ) 
        obs1,obs2 = f(xx)
        x1,x2 = xx
    
        transfer_params = nnx.All(nnx.PathContains("bias"))  # select only the bias parameter
        diff_state = nnx.DiffState(0, transfer_params)       # differentiate arg 0 (the model) w.r.t. the filtered params only

        loss_grad = nnx.value_and_grad(compute_loss, argnums=diff_state, has_aux=True)
        @nnx.scan(
            in_axes=(nnx.Carry,None,None,),
            out_axes=(nnx.Carry,0),
            length=epochs
        )
        def optimizer_scan( optimizer, x, obs ):
            (loss,mae), grads = loss_grad( optimizer.model, x, obs )        
            optimizer.update( grads )
            return optimizer, (loss,mae)
    
        optimizer_transfer = nnx.Optimizer(model, optax.adam(learning_rate=1e-3), wrt = transfer_params)
    
        optimizer, (losses,maes) = optimizer_scan( optimizer_transfer, x1, obs1 )
    
        print( ' AFTER TRAINING' )
        print( 'training loss:', losses[-1] )
    
        y1,y2 = optimizer.model(xx)
        error = obs2-y2
        loss = jnp.mean( error*error )
        print( 'test loss:',loss )
        print( 'm approximation:', optimizer.model.kernel.value )
        print( 'b approximation:', optimizer.model.bias.value )
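
    As a quick sanity check (my addition, not part of the original answer), you can append the following lines inside the __main__ block to confirm that the frozen kernel really is untouched after training:

        # Only 'bias' ever receives gradient updates, so the pretrained kernel
        # should still be exactly [[2.0]].
        assert jnp.allclose(optimizer.model.kernel.value, jnp.array([[2.0]]))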