python, machine-learning, pytorch

Should nested modules with shared weights be an nn.Module object parameter or not?


I would like two torch.nn.Module classes to share part of their architecture and weights, as in the example below:

import torch
from torch import nn

class SharedBlock(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

        self.block = nn.Sequential(
            # Define some block architecture here...
        )

    def forward(self, x):
        return self.block(x)

class MyNestedModule(nn.Module):
    def __init__(self, shared_block: nn.Module, *args, **kwargs):
        super().__init__()

        self.linear = nn.Linear(...)
        self.shared_block = shared_block

    def forward(self, x):
        return self.shared_block(self.linear(x))

class MyModule(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

        
        # SHOULD THIS BE:
        shared_block = SharedBlock(*args, **kwargs)
        # OR:
        self.shared_block = SharedBlock(*args, **kwargs)  # Note: self.
        # ...AND WHAT IS THE DIFFERENCE, IF ANY?


        self.nested1 = MyNestedModule(shared_block, *args, **kwargs)
        self.nested2 = MyNestedModule(shared_block, *args, **kwargs)

    def forward(self, x):
        x_1, x_2 = torch.split(x, x.shape[0] // 2, dim=0)
        y_1 = self.nested1(x_1)
        y_2 = self.nested2(x_2)
        return y_1, y_2

I would like to know whether shared_block should be assigned as an attribute of MyModule (i.e. self.shared_block). I assume it should not, since it is already assigned as an attribute on both MyNestedModule instances and should therefore be registered with autograd. But what would happen if I did also assign it as an attribute of MyModule?


Solution

  • It doesn't matter; the parameters are tracked both ways. If you use shared_block = ..., the parameters in shared_block will be referenced in your state dict (model.state_dict()) twice: once for self.nested1 and again for self.nested2.

    If you use the self.shared_block = ... approach, the state dict will reference the parameters a third time in MyModule itself.

    Either way, the parameters are tracked and model.parameters() will return a non-duplicated set of parameters.

    You can run this code to look at a simplified version:

    import torch
    from torch import nn
    
    class SharedBlock(nn.Module):
        def __init__(self):
            super().__init__()
    
            self.block = nn.Linear(8, 8)
    
        def forward(self, x):
            return self.block(x)
    
    class MyNestedModule(nn.Module):
        def __init__(self, shared_block):
            super().__init__()
    
            self.shared_block = shared_block
    
        def forward(self, x):
            return self.shared_block(x)
        
    class MyModule1(nn.Module):
        def __init__(self):
            super().__init__()
            shared_block = SharedBlock()
    
            self.nested1 = MyNestedModule(shared_block)
            self.nested2 = MyNestedModule(shared_block)
    
        def forward(self, x):
            x_1, x_2 = torch.split(x, x.shape[0] // 2, dim=0)
            y_1 = self.nested1(x_1)
            y_2 = self.nested2(x_2)
            return y_1, y_2
        
    class MyModule2(nn.Module):
        def __init__(self):
            super().__init__()
            self.shared_block = SharedBlock()
    
            self.nested1 = MyNestedModule(self.shared_block)
            self.nested2 = MyNestedModule(self.shared_block)
    
        def forward(self, x):
            x_1, x_2 = torch.split(x, x.shape[0] // 2, dim=0)
            y_1 = self.nested1(x_1)
            y_2 = self.nested2(x_2)
            return y_1, y_2
        
        
    model1 = MyModule1()
    print(model1.state_dict())
    print(list(model1.parameters()))
    
    model2 = MyModule2()
    print(model2.state_dict())
    print(list(model2.parameters()))
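
    As a quick follow-up check (a sketch, assuming the definitions above), you can confirm that both nested modules hold references to the very same Parameter objects, and that parameters() deduplicates them while state_dict() lists them under every registered path:

    # Both nested modules point at the same Parameter object.
    w1 = model1.nested1.shared_block.block.weight
    w2 = model1.nested2.shared_block.block.weight
    print(w1 is w2)  # True

    # state_dict() lists the shared weights under every registered path...
    print(list(model1.state_dict().keys()))

    # ...but parameters() yields each Parameter only once.
    print(len(list(model1.parameters())))  # 2: one weight, one bias
    print(len(list(model2.parameters())))  # still 2, despite the extra self.shared_block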