
Why doesn't torch pruning actually remove filters or weights?


I'm working with an architecture and trying to sparsify it with pruning. I wrote functions for pruning; here is one of them:

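import torch.nn as nn
import torch.nn.utils.prune as prune

def prune_model_l1_unstructured(model, layer_type, proportion):
    # Zero out the `proportion` fraction of smallest-magnitude weights
    # in every layer of type `layer_type`
    for module in model.modules():
        if isinstance(module, layer_type):
            prune.l1_unstructured(module, 'weight', amount=proportion)
            # Make the pruning permanent: drop weight_orig and weight_mask
            prune.remove(module, 'weight')
    return model

# prune model
prune_model_l1_unstructured(model, nn.Conv2d, 0.5)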

It prunes some weights (sets them to zero). But prune.remove just deletes the original weights and keeps the zeros instead. The total number of parameters is still the same (I checked it). The model file (model.pt) is still the same size, too, and the model's speed hasn't changed either. I also tried global pruning and structured L1 pruning, with the same results. So how can this help improve the model's inference time? Why aren't the weights being removed, and how can I remove the pruned connections?
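The observation can be reproduced with a small sketch. It assumes a hypothetical two-layer Conv2d toy model standing in for the real architecture, plus a hypothetical helper `count_params` that counts all parameter elements and the nonzero ones. Unstructured pruning leaves the total count unchanged and only adds zeros; structured L1 pruning (`prune.ln_structured` with `n=1`, `dim=0`) zeros whole output filters but keeps the weight shape as it was.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def prune_model_l1_unstructured(model, layer_type, proportion):
    # Same loop as in the question
    for module in model.modules():
        if isinstance(module, layer_type):
            prune.l1_unstructured(module, 'weight', amount=proportion)
            prune.remove(module, 'weight')
    return model


def count_params(model):
    """Hypothetical helper: return (total, nonzero) element counts."""
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum(int(torch.count_nonzero(p)) for p in model.parameters())
    return total, nonzero


torch.manual_seed(0)
# Hypothetical toy model, not the asker's architecture
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))

total_before, nonzero_before = count_params(model)
prune_model_l1_unstructured(model, nn.Conv2d, 0.5)
total_after, nonzero_after = count_params(model)
print(total_before, total_after)      # the total does not change
print(nonzero_before, nonzero_after)  # only the nonzero count drops

# Structured L1 pruning: zero 50% of the output filters (dim=0)
conv = nn.Conv2d(3, 8, 3)
prune.ln_structured(conv, 'weight', amount=0.5, n=1, dim=0)
prune.remove(conv, 'weight')
# A filter is "pruned" if every one of its weights is zero
zero_filters = int((conv.weight.flatten(1).abs().sum(dim=1) == 0).sum())
print(conv.weight.shape, zero_filters)  # shape is still [8, 3, 3, 3]
```

Because the tensors keep their dense shapes, the saved state_dict stays the same size and the convolutions still run dense kernels over the zeros, which is consistent with the unchanged file size and speed.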


Solution

  • TL;DR: PyTorch's pruning functions only work as a weight mask; that's all they do. There are no memory savings associated with using torch.nn.utils.prune.

    As the documentation page for torch.nn.utils.prune.remove states:

    Removes the pruning reparameterization from a module and the pruning method from the forward hook.

    In effect, this removes the mask that prune.l1_unstructured attached to the parameter. The previously masked entries are left as zeros in the weight tensor, but they are no longer forced to stay at zero: any further training can update them. Moreover, while the reparameterization is in place, the module stores both weight_orig and weight_mask, so PyTorch pruning actually takes more memory than not pruning at all; after prune.remove you are back to the original dense size. So this is not the functionality you are looking for.

    You can read more in this comment.


    Here is an example:

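    >>> import torch.nn as nn
    >>> import torch.nn.utils.prune as prune
    >>> module = nn.Linear(10, 3)
    >>> prune.l1_unstructured(module, name='weight', amount=0.3)
    Linear(in_features=10, out_features=3, bias=True)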

    The weight parameters are masked:

    >>> module.weight
    tensor([[-0.0000,  0.0000, -0.1397, -0.0942,  0.0000,  0.0000,  0.0000, -0.1452,
              0.0401,  0.1098],
            [ 0.2909, -0.0000,  0.2871,  0.1725,  0.0000,  0.0587,  0.0795, -0.1253,
              0.0764, -0.2569],
            [ 0.0000, -0.3054, -0.2722,  0.2414,  0.1737, -0.0000, -0.2825,  0.0685,
              0.1616,  0.1095]], grad_fn=<MulBackward0>)
    

    Here is the mask:

    >>> module.weight_mask
    tensor([[0., 0., 1., 1., 0., 0., 0., 1., 1., 1.],
            [1., 0., 1., 1., 0., 1., 1., 1., 1., 1.],
            [0., 1., 1., 1., 1., 0., 1., 1., 1., 1.]])
    

    Notice that applying prune.remove removes the pruning reparameterization. The masked values remain zero, but they are no longer frozen:

    >>> prune.remove(module, 'weight')
    Linear(in_features=10, out_features=3, bias=True)

    >>> module.weight
    Parameter containing:
    tensor([[-0.0000,  0.0000, -0.1397, -0.0942,  0.0000,  0.0000,  0.0000, -0.1452,
              0.0401,  0.1098],
            [ 0.2909, -0.0000,  0.2871,  0.1725,  0.0000,  0.0587,  0.0795, -0.1253,
              0.0764, -0.2569],
            [ 0.0000, -0.3054, -0.2722,  0.2414,  0.1737, -0.0000, -0.2825,  0.0685,
              0.1616,  0.1095]], requires_grad=True)
    

    And the mask is gone:

    >>> hasattr(module, 'weight_mask')
    False
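Both claims in the answer, the extra storage while the mask is attached and the zeros being "unfrozen" after prune.remove, can be checked with a minimal sketch. The helper `train_step`, the toy squared-output loss, and the learning rate are assumptions for illustration only. With the mask attached, the state_dict holds both `weight_orig` and `weight_mask`, and the masked positions stay zero after an SGD step; after `prune.remove`, only `weight` is stored and a single step moves the former zeros away from zero.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def train_step(module):
    """Hypothetical single SGD step on a toy squared-output loss."""
    opt = torch.optim.SGD(module.parameters(), lr=0.1)
    loss = module(torch.randn(4, 10)).pow(2).sum()
    loss.backward()
    opt.step()


torch.manual_seed(0)

# Case 1: pruning reparameterization still attached
masked = nn.Linear(10, 3)
prune.l1_unstructured(masked, name='weight', amount=0.3)
masked_keys = sorted(masked.state_dict().keys())
print(masked_keys)  # ['bias', 'weight_mask', 'weight_orig']
mask = masked.weight_mask.clone()
train_step(masked)  # updates weight_orig, but masked gradients are zero
masked(torch.randn(1, 10))  # forward pre-hook recomputes weight = weight_orig * weight_mask
masked_still_zero = bool((masked.weight[mask == 0] == 0).all())

# Case 2: reparameterization removed with prune.remove
removed = nn.Linear(10, 3)
prune.l1_unstructured(removed, name='weight', amount=0.3)
pruned = removed.weight_mask == 0  # capture before the mask is deleted
prune.remove(removed, 'weight')
removed_keys = sorted(removed.state_dict().keys())
print(removed_keys)  # ['bias', 'weight']
train_step(removed)  # weight is a plain Parameter again, so every entry gets a gradient
removed_still_zero = bool((removed.weight[pruned] == 0).all())

print(masked_still_zero, removed_still_zero)
```

In other words, the mask is what keeps the pruned entries at zero; once it is removed, nothing in torch.nn.utils.prune shrinks the tensor or keeps those entries at zero.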