
How to overwrite nn.Conv2d


I want to overwrite nn.Conv2d so that ready-made models such as ResNet, AlexNet, etc. use my replacement, without my having to change every nn.Conv2d in the model manually.

from torchvision import models
from torch import nn

class replace_conv2d(nn.Module):
    ...  # other code

nn.Conv2d = replace_conv2d  # what I want to do
model = models.resnet18()

so that resnet18 uses the replace_conv2d class instead of nn.Conv2d.


Solution

  • I am not sure you can overwrite the modules once they are loaded. What you can do, though, is wrap the nn.Module with a function that walks the module tree and replaces each nn.Conv2d with another layer implementation (nn.Identity in the example here). The only trick is that child layers are identified by compound keys. For example, model.layer1[0].conv2 has the keys "layer1", "0", and finally "conv2"; named_modules() reports it as "layer1.0.conv2".
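
To see the compound keys concretely, the sketch below builds a small stand-in model from torch alone and lists the names named_modules() gives its convolutions. TinyNet and Block are hypothetical names that loosely mirror one ResNet stage, so torchvision is not needed; for models.resnet18() the names follow the same dotted pattern, e.g. "layer1.0.conv2".

```python
from torch import nn


class Block(nn.Module):
    # hypothetical stand-in for a ResNet BasicBlock: conv -> ReLU -> conv
    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1, bias=False)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class TinyNet(nn.Module):
    # hypothetical model: a stem conv followed by a Sequential of two blocks
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.layer1 = nn.Sequential(Block(8), Block(8))

    def forward(self, x):
        return self.layer1(self.conv1(x))


model = TinyNet()
# dotted names of every nn.Conv2d, at any depth, in registration order
conv_names = [name for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]
print(conv_names)
# splitting a name yields the compound key used to walk the tree
print(conv_names[2].split('.'))
```

Note that the index inside nn.Sequential shows up as the string "0", which getattr accepts just like an ordinary attribute name.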

    Gather the nn.Conv2d layers and split their compound keys:

    convs = []
    for k, v in model.named_modules():
        if isinstance(v, nn.Conv2d):
            convs.append(k.split('.'))
    

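    Build a recursive function that, given a compound key, returns the parent module of the layer it names (it stops one level short, so the last key can be passed to setattr):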

    inspect = lambda m, k: inspect(getattr(m, k[0]), k[1:]) if len(k)>1 else m
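
The lambda stops one level short on purpose: for a key such as ['layer1', '0', 'conv2'] it returns model.layer1[0], the module that owns conv2, because setattr needs that parent plus the final attribute name. A minimal sketch of the same lookup as an ordinary loop, assuming a hypothetical helper name get_parent (which also avoids shadowing the standard-library inspect module):

```python
from torch import nn


def get_parent(root, key):
    """Hypothetical helper: follow every component of a split compound key
    except the last one and return the module that owns that final name."""
    module = root
    for name in key[:-1]:
        # getattr also resolves numeric keys such as "0" on nn.Sequential
        module = getattr(module, name)
    return module


# hypothetical two-level model: a Sequential nested inside a Sequential
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.ReLU(), nn.Conv2d(8, 8, 3)),
)
key = '1.1'.split('.')           # ['1', '1']
parent = get_parent(model, key)  # the inner nn.Sequential
print(type(parent).__name__, getattr(parent, key[-1]))
```

For a top-level layer (a key with a single component) the loop body never runs and the root itself is returned, which matches the lambda's `else m` branch.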
    

    Finally, iterate over the collected keys and replace each layer on its parent:

    for k in convs:
        setattr(inspect(model, k), k[-1], nn.Identity())
    

    All nn.Conv2d layers, whatever their depth, are now replaced:

    >>> model.layer1[0].conv2
    Identity()
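
The three steps can be folded into one reusable helper. This is a sketch under my own naming: replace_conv2d_layers and its make_layer callback are hypothetical, not a PyTorch API. Because make_layer receives the old nn.Conv2d, the same helper covers both the nn.Identity case and building a new layer from the old conv's settings.

```python
from torch import nn


def replace_conv2d_layers(model, make_layer):
    """Hypothetical helper: replace every nn.Conv2d in `model`, at any depth,
    in place. `make_layer(old_conv)` returns the module to put in its place.
    Assumes the root module itself is not an nn.Conv2d."""
    # collect first, then mutate, so the tree is not modified while
    # named_modules() is still walking it
    targets = [name for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]
    for name in targets:
        *path, attr = name.split('.')
        parent = model
        for p in path:
            parent = getattr(parent, p)
        setattr(parent, attr, make_layer(getattr(parent, attr)))
    return model


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.ReLU(), nn.Conv2d(8, 8, 3)),
)
replace_conv2d_layers(model, lambda conv: nn.Identity())
print(model)
```

The same call works on a torchvision model, e.g. replace_conv2d_layers(models.resnet18(), ...), since it only relies on named_modules(), getattr and setattr.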
    

    If you want to build the replacement from the configuration of the conv layer you are about to replace, read its attributes:

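    # Run this instead of the nn.Identity() loop above: once a layer has
    # been replaced by Identity, these attributes no longer exist.
    keys = ('in_channels', 'out_channels', 'kernel_size',
            'stride', 'padding', 'dilation', 'groups', 'padding_mode')

    for k in convs:
        parent = inspect(model, k)
        conv = getattr(parent, k[-1])
        kwargs = {a: getattr(conv, a) for a in keys}
        # conv.bias is a Parameter or None, but the constructor expects a bool
        kwargs['bias'] = conv.bias is not None
        # put your own class (e.g. replace_conv2d) here; weights are freshly
        # initialized, not copied from conv
        setattr(parent, k[-1], nn.Conv2d(**kwargs))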
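
If the goal from the question is a custom replace_conv2d that every layer gets swapped for, one option (an assumption on my part, not something the steps above prescribe) is to subclass nn.Conv2d so it accepts exactly the same constructor arguments, then copy the old weights across with load_state_dict. ReplaceConv2d and convert are hypothetical names, and forward only delegates to the stock convolution as a placeholder for custom logic.

```python
import torch
from torch import nn


class ReplaceConv2d(nn.Conv2d):
    """Hypothetical drop-in replacement: same constructor and parameters as
    nn.Conv2d, so it can be built from an existing layer's attributes."""

    def forward(self, x):
        # placeholder for custom behaviour; here it just runs the stock conv
        return super().forward(x)


KEYS = ('in_channels', 'out_channels', 'kernel_size', 'stride',
        'padding', 'dilation', 'groups', 'padding_mode')


def convert(conv):
    """Hypothetical converter: build a ReplaceConv2d with the same settings
    as `conv` and copy its weight and bias into it."""
    kwargs = {a: getattr(conv, a) for a in KEYS}
    kwargs['bias'] = conv.bias is not None
    new = ReplaceConv2d(**kwargs)
    # keep the trained (or pretrained) weights instead of a fresh init
    new.load_state_dict(conv.state_dict())
    return new.to(device=conv.weight.device, dtype=conv.weight.dtype)


old = nn.Conv2d(3, 8, 3, padding=1)
new = convert(old)
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(old(x), new(x)))
```

A side effect of subclassing is that isinstance(layer, nn.Conv2d) stays True for the new layers, so code that looks for convolutions (weight-initialization loops, for example) keeps working; the flip side is that running the replacement loop a second time would convert them again.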