python, machine-learning, deep-learning, pytorch, computer-vision

Using zip() on two nn.ModuleList


Is zipping two different nn.ModuleList objects a correct way to build the computational graph for training a neural net in PyTorch? nn.ModuleList is a wrapper around a Python list that registers each contained module so its parameters are tracked during training.
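
For context on the registration point above: parameters held in a plain Python list are invisible to PyTorch, while an nn.ModuleList exposes them through model.parameters(). A minimal sketch (class names are just for illustration):

    import torch.nn as nn

    class PlainList(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = [nn.Linear(8, 8)]  # plain list: module is NOT registered

    class RegisteredList(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(8, 8)])  # registered

    print(len(list(PlainList().parameters())))       # 0 -- an optimizer would miss these
    print(len(list(RegisteredList().parameters())))  # 2 -- weight and bias are tracked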

I'm building a network that consists of two alternating types of blocks in __init__:

    def __init__(self, in_channels):
        super().__init__()

        self.encoder_conv_blocks = nn.ModuleList()
        self.downsample_blocks = nn.ModuleList()

        for out_channels in _FILTERS:
            conv_block = _ConvBlock(in_channels, _CONV_BLOCK_LEN, _CONV_BLOCK_GROWTH_RATE)
            downsample_block = _DownsampleBlock(conv_block.out_channels, out_channels)

            self.encoder_conv_blocks.append(conv_block)
            self.downsample_blocks.append(downsample_block)

            in_channels = out_channels

Later, in forward, I zip the two lists, since I need the outputs of the first type of block for skip connections:

    def forward(self, x):
        skip_connections = []
        
        for conv_block, downsample_block in zip(self.encoder_conv_blocks,
                                                self.downsample_blocks):
            x = conv_block(x)
            skip_connections.append(x)
            x = downsample_block(x)

However, when printing the summary with torchinfo, the summary of the modules registered via two zipped nn.ModuleList objects looks different from the summary produced when a single nn.ModuleList is used. I suspect this could cause issues for training and inference down the line.
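
For reference, the summaries below were produced with a torchinfo call along these lines (the exact invocation isn't shown here; the input shape is read off the tables):

    from torchinfo import summary

    model = MyNet(in_channels=4)  # 4 input channels, matching [16, 4, 128, 256] below
    summary(model, input_size=(16, 4, 128, 256))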

zip(nn.ModuleList(), nn.ModuleList()):

========================================================================================================================
Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
========================================================================================================================
MyNet                                     [16, 4, 128, 256]         [16, 3, 128, 256]         --
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-1                        [16, 4, 128, 256]         [16, 84, 128, 256]        26,360
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-2                  [16, 84, 128, 256]        [16, 64, 64, 128]         48,448
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-3                        [16, 64, 64, 128]         [16, 144, 64, 128]        70,160
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-4                  [16, 144, 64, 128]        [16, 128, 32, 64]         166,016
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-5                        [16, 128, 32, 64]         [16, 208, 32, 64]         116,880
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-6                  [16, 208, 32, 64]         [16, 128, 16, 32]         239,744
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-7                        [16, 128, 16, 32]         [16, 208, 16, 32]         116,880
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-8                  [16, 208, 16, 32]         [16, 128, 8, 16]          239,744
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-9                        [16, 128, 8, 16]          [16, 208, 8, 16]          116,880
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-10                 [16, 208, 8, 16]          [16, 256, 4, 8]           479,488
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-11                       [16, 256, 4, 8]           [16, 336, 4, 8]           210,320
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-12                 [16, 336, 4, 8]           [16, 256, 2, 4]           774,400
├─ModuleList: 1-13                            --                        --                        (recursive)
│    └─_ConvBlock: 2-13                       [16, 256, 2, 4]           [16, 336, 2, 4]           210,320
├─ModuleList: 1-14                            --                        --                        (recursive)
│    └─_DownsampleBlock: 2-14                 [16, 336, 2, 4]           [16, 512, 1, 2]           1,548,800

single nn.ModuleList():

========================================================================================================================
Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
========================================================================================================================
MyNet                                     [16, 4, 128, 256]         [16, 3, 128, 256]         --
├─ModuleList: 1-1                             --                        --                        --
│    └─_ConvBlock: 2-1                        [16, 4, 128, 256]         [16, 84, 128, 256]        26,360
│    └─_DownsampleBlock: 2-2                  [16, 84, 128, 256]        [16, 64, 64, 128]         48,448
│    └─_ConvBlock: 2-3                        [16, 64, 64, 128]         [16, 144, 64, 128]        70,160
│    └─_DownsampleBlock: 2-4                  [16, 144, 64, 128]        [16, 128, 32, 64]         166,016
│    └─_ConvBlock: 2-5                        [16, 128, 32, 64]         [16, 208, 32, 64]         116,880
│    └─_DownsampleBlock: 2-6                  [16, 208, 32, 64]         [16, 128, 16, 32]         239,744
│    └─_ConvBlock: 2-7                        [16, 128, 16, 32]         [16, 208, 16, 32]         116,880
│    └─_DownsampleBlock: 2-8                  [16, 208, 16, 32]         [16, 128, 8, 16]          239,744
│    └─_ConvBlock: 2-9                        [16, 128, 8, 16]          [16, 208, 8, 16]          116,880
│    └─_DownsampleBlock: 2-10                 [16, 208, 8, 16]          [16, 256, 4, 8]           479,488
│    └─_ConvBlock: 2-11                       [16, 256, 4, 8]           [16, 336, 4, 8]           210,320
│    └─_DownsampleBlock: 2-12                 [16, 336, 4, 8]           [16, 256, 2, 4]           774,400
│    └─_ConvBlock: 2-13                       [16, 256, 2, 4]           [16, 336, 2, 4]           210,320
│    └─_DownsampleBlock: 2-14                 [16, 336, 2, 4]           [16, 512, 1, 2]           1,548,800

Solution

  • Both methods are equivalent; the change in the print-out is just an artifact of how torchinfo crawls the model.

    torchinfo traces the model's forward pass, recording every module involved. If the same module appears more than once, the repeat appearances are labeled (recursive). For nn.ModuleList objects, using items from the same ModuleList at different points of the forward pass gets the container flagged as recursive, simply because the ModuleList shows up more than once in different places. Here's a simple example:

    Example 1:

    import torch.nn as nn
    from torchinfo import summary

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
            self.l2 = nn.Linear(8, 8)

        def forward(self, x):
            x = self.l1[0](x)  # both ModuleList items used back to back...
            x = self.l1[1](x)
            x = self.l2(x)     # ...before any other module
            return x

    m = MyModel()
    summary(m, (1, 8), depth=5)
    
    ==========================================================================================
    Layer (type:depth-idx)                   Output Shape              Param #
    ==========================================================================================
    MyModel                                  [1, 8]                    --
    ├─ModuleList: 1-1                        --                        --
    │    └─Linear: 2-1                       [1, 8]                    72
    │    └─Linear: 2-2                       [1, 8]                    72
    ├─Linear: 1-2                            [1, 8]                    72
    ==========================================================================================
    Total params: 216
    Trainable params: 216
    Non-trainable params: 0
    Total mult-adds (M): 0.00
    ==========================================================================================
    Input size (MB): 0.00
    Forward/backward pass size (MB): 0.00
    Params size (MB): 0.00
    Estimated Total Size (MB): 0.00
    ==========================================================================================
    

    Example 2:

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
            self.l2 = nn.Linear(8, 8)

        def forward(self, x):
            x = self.l1[0](x)  # ModuleList item used...
            x = self.l2(x)     # ...then a different module...
            x = self.l1[1](x)  # ...then the same ModuleList again
            return x

    m = MyModel()
    summary(m, (1, 8), depth=5)
    
    ==========================================================================================
    Layer (type:depth-idx)                   Output Shape              Param #
    ==========================================================================================
    MyModel                                  [1, 8]                    --
    ├─ModuleList: 1-3                        --                        (recursive)
    │    └─Linear: 2-1                       [1, 8]                    72
    ├─Linear: 1-2                            [1, 8]                    72
    ├─ModuleList: 1-3                        --                        (recursive)
    │    └─Linear: 2-2                       [1, 8]                    72
    ==========================================================================================
    Total params: 216
    Trainable params: 216
    Non-trainable params: 0
    Total mult-adds (M): 0.00
    ==========================================================================================
    Input size (MB): 0.00
    Forward/backward pass size (MB): 0.00
    Params size (MB): 0.00
    Estimated Total Size (MB): 0.00
    ==========================================================================================
    

    In the first example, all layers in the ModuleList are used consecutively, and no recursive flag appears. In the second, layers from the same ModuleList are used at different points of the forward pass, and the ModuleList container itself gets the recursive flag. The parameter totals are identical either way; the flag is purely an artifact of how torchinfo crawls the model.
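
    If you want to convince yourself that both orderings register exactly the same parameters, a quick check helps. (MyModelA and MyModelB are hypothetical renames of the two MyModel classes above, which otherwise share a name.)

    ma, mb = MyModelA(), MyModelB()

    # Same state_dict keys -> same registered parameters, no matter the
    # order in which forward uses them.
    assert list(ma.state_dict().keys()) == list(mb.state_dict().keys())
    print(sum(p.numel() for p in ma.parameters()))  # 216, matching both summaries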

    As a purely stylistic note, there's nothing wrong with zipping ModuleLists, but if you know each _ConvBlock is always paired one-to-one with a _DownsampleBlock, you might consider combining them into a single module:

    class CombinedBlock(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            # Constructor calls mirror the pairing from the question's __init__
            self.conv_block = _ConvBlock(in_channels, _CONV_BLOCK_LEN, _CONV_BLOCK_GROWTH_RATE)
            self.down_block = _DownsampleBlock(self.conv_block.out_channels, out_channels)

        def forward(self, x):
            x = self.conv_block(x)
            skip = x  # keep the pre-downsample activation for the skip connection
            x = self.down_block(x)
            return x, skip
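
    With that, the encoder needs only a single ModuleList and forward stays flat. A sketch of how the question's __init__ and forward might then look (same names as the question, decoder still elided):

    class MyNet(nn.Module):
        def __init__(self, in_channels):
            super().__init__()
            self.encoder_blocks = nn.ModuleList()
            for out_channels in _FILTERS:
                self.encoder_blocks.append(CombinedBlock(in_channels, out_channels))
                in_channels = out_channels

        def forward(self, x):
            skip_connections = []
            for block in self.encoder_blocks:
                x, skip = block(x)
                skip_connections.append(skip)
            # ... decoder / skip-connection usage continues as in the original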