python machine-learning keras tensorflow2.0 pruning

How to set prunable layers for tfmot.sparsity.keras.prune_low_magnitude?


I am applying the pruning function from tensorflow_model_optimization, tfmot.sparsity.keras.prune_low_magnitude(), to MobileNetV2.

Is there any way to mark only some layers of the model as prunable? For training, each layer has a "trainable" attribute, but I haven't found an equivalent for pruning.

Any ideas or comments will be appreciated! :)


Solution

  • In the end I found that you can also apply prune_low_magnitude() per layer.

    So the workaround is to define a list of the names or types of the layers that should be pruned, and apply prune_low_magnitude() only to the layers that match that list.
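A minimal sketch of that workaround, assuming a Sequential or functional Keras model (MobileNetV2 is functional) and a TensorFlow/Keras version that tensorflow_model_optimization supports. It uses tf.keras.models.clone_model with a clone_function rather than a hand-written loop over the layers. The helper names should_prune and apply_selective_pruning and their parameters are made up for illustration; they are not part of tfmot.

```python
def should_prune(layer, prunable_names=(), prunable_types=()):
    """Return True if the layer is selected by name or by type.

    Hypothetical helper: `prunable_names` is an iterable of layer names,
    `prunable_types` an iterable of layer classes.
    """
    return layer.name in prunable_names or isinstance(layer, tuple(prunable_types))


def apply_selective_pruning(model, prunable_names=(), prunable_types=(),
                            **pruning_params):
    """Clone `model`, wrapping only the selected layers for pruning.

    Extra keyword arguments (for example `pruning_schedule`) are passed
    through to prune_low_magnitude() unchanged.
    """
    # Imported lazily so should_prune() stays usable without TensorFlow.
    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    def clone_fn(layer):
        if should_prune(layer, prunable_names, prunable_types):
            # Wrap the selected layer so its weights get pruned in training.
            return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
        # Every other layer is returned as-is and stays unpruned.
        return layer

    return tf.keras.models.clone_model(model, clone_function=clone_fn)
```

For example, apply_selective_pruning(model, prunable_types=[tf.keras.layers.Conv2D]) would wrap only the Conv2D layers. When training the returned model, the tfmot.sparsity.keras.UpdatePruningStep() callback has to be passed to fit(), otherwise the pruning wrappers raise an error.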