
How to freeze layer parameters in Flux.jl


I am currently working on a transfer learning problem, so I would like to freeze most of my layers so that they are not re-trained and only the final layer's weights are modified. How can I choose which layers to freeze with Flux.jl?


Solution

  • Flux provides a simple interface for this: pass only the layers whose parameters you want to update to the Flux.params() function, as shown below:

    using Flux

    m = Chain(
          Dense(784, 64, relu),
          Dense(64, 64, relu),
          Dense(64, 10)
        )

    ps = Flux.params(m[3:end])  # collect only the final Dense layer's weights and bias
    

    In the above example, only the final Dense layer is updated, which is the usual setup in transfer learning; a minimal training sketch is included at the end of this answer.

    You can see a complete, step-by-step example in the Flux.jl tutorial on transfer learning: https://fluxml.ai/tutorials/2020/10/18/transfer-learning.html
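
    For completeness, here is a minimal sketch of how the restricted ps is then used during training. It assumes the implicit-parameter Flux.train! API that Flux.params belongs to; the toy data, optimiser, and loss function are placeholders added for illustration, not part of the original question:

    using Flux

    # Hypothetical toy batch of 16 samples, just to make the sketch runnable.
    x = rand(Float32, 784, 16)
    y = Flux.onehotbatch(rand(0:9, 16), 0:9)

    m = Chain(
          Dense(784, 64, relu),
          Dense(64, 64, relu),
          Dense(64, 10)
        )

    ps  = Flux.params(m[3:end])   # parameters of the final Dense layer only
    opt = Descent(0.01)

    loss(x, y) = Flux.logitcrossentropy(m(x), y)

    # Only the parameters collected in ps are updated; the first two
    # Dense layers are left untouched (frozen).
    Flux.train!(loss, ps, [(x, y)], opt)

    Because the optimiser only ever sees the parameters in ps, the gradients of the frozen layers are never applied, even though those layers still participate in the forward pass.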