Is using two different zipped nn.ModuleList() lists correct to build the computational graph for training a neural net in PyTorch?
nn.ModuleList is a wrapper around Python's list that additionally registers each contained module, so its parameters are tracked and trained as part of the parent model.
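For instance, a minimal sketch of that registration behaviour (not part of my network): layers kept in a plain Python list are invisible to the model, while the same layers in an nn.ModuleList show up in parameters():

import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(8, 8)]                  # plain list: not registered

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8)])   # registered as submodules

print(len(list(PlainList().parameters())))    # 0
print(len(list(Registered().parameters())))   # 2 (weight and bias)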
I'm building a network whose encoder consists of two alternating types of blocks, created in __init__:
def __init__(self, in_channels):
    super().__init__()
    self.encoder_conv_blocks = nn.ModuleList()
    self.downsample_blocks = nn.ModuleList()

    for out_channels in _FILTERS:
        conv_block = _ConvBlock(in_channels, _CONV_BLOCK_LEN, _CONV_BLOCK_GROWTH_RATE)
        downsample_block = _DownsampleBlock(conv_block.out_channels, out_channels)

        self.encoder_conv_blocks.append(conv_block)
        self.downsample_blocks.append(downsample_block)

        in_channels = out_channels
Later, in forward, I zip the two lists, because I need the outputs of the first type of block for skip connections:
def forward(self, x):
    skip_connections = []
    for conv_block, downsample_block in zip(self.encoder_conv_blocks,
                                            self.downsample_blocks):
        x = conv_block(x)
        skip_connections.append(x)
        x = downsample_block(x)
However, when printing the model summary with torchinfo, the summary of the modules registered via two zipped nn.ModuleList containers looks different from the summary produced when a single nn.ModuleList is used. I suspect this could cause issues for training and inference later on.
zip(nn.ModuleList(), nn.ModuleList()):
========================================================================================================================
Layer (type:depth-idx) Input Shape Output Shape Param #
========================================================================================================================
MyNet [16, 4, 128, 256] [16, 3, 128, 256] --
├─ModuleList: 1-13 -- -- (recursive)
│ └─_ConvBlock: 2-1 [16, 4, 128, 256] [16, 84, 128, 256] 26,360
├─ModuleList: 1-14 -- -- (recursive)
│ └─_DownsampleBlock: 2-2 [16, 84, 128, 256] [16, 64, 64, 128] 48,448
├─ModuleList: 1-13 -- -- (recursive)
│ └─_ConvBlock: 2-3 [16, 64, 64, 128] [16, 144, 64, 128] 70,160
├─ModuleList: 1-14 -- -- (recursive)
│ └─_DownsampleBlock: 2-4 [16, 144, 64, 128] [16, 128, 32, 64] 166,016
├─ModuleList: 1-13 -- -- (recursive)
│ └─_ConvBlock: 2-5 [16, 128, 32, 64] [16, 208, 32, 64] 116,880
├─ModuleList: 1-14 -- -- (recursive)
│ └─_DownsampleBlock: 2-6 [16, 208, 32, 64] [16, 128, 16, 32] 239,744
├─ModuleList: 1-13 -- -- (recursive)
│ └─_ConvBlock: 2-7 [16, 128, 16, 32] [16, 208, 16, 32] 116,880
├─ModuleList: 1-14 -- -- (recursive)
│ └─_DownsampleBlock: 2-8 [16, 208, 16, 32] [16, 128, 8, 16] 239,744
├─ModuleList: 1-13 -- -- (recursive)
│ └─_ConvBlock: 2-9 [16, 128, 8, 16] [16, 208, 8, 16] 116,880
├─ModuleList: 1-14 -- -- (recursive)
│ └─_DownsampleBlock: 2-10 [16, 208, 8, 16] [16, 256, 4, 8] 479,488
├─ModuleList: 1-13 -- -- (recursive)
│ └─_ConvBlock: 2-11 [16, 256, 4, 8] [16, 336, 4, 8] 210,320
├─ModuleList: 1-14 -- -- (recursive)
│ └─_DownsampleBlock: 2-12 [16, 336, 4, 8] [16, 256, 2, 4] 774,400
├─ModuleList: 1-13 -- -- (recursive)
│ └─_ConvBlock: 2-13 [16, 256, 2, 4] [16, 336, 2, 4] 210,320
├─ModuleList: 1-14 -- -- (recursive)
│ └─_DownsampleBlock: 2-14 [16, 336, 2, 4] [16, 512, 1, 2] 1,548,800
single nn.ModuleList():
MyNet [16, 4, 128, 256] [16, 3, 128, 256] --
├─ModuleList: 1-1 -- -- --
│ └─_ConvBlock: 2-1 [16, 4, 128, 256] [16, 84, 128, 256] 26,360
│ └─_DownsampleBlock: 2-2 [16, 84, 128, 256] [16, 64, 64, 128] 48,448
│ └─_ConvBlock: 2-3 [16, 64, 64, 128] [16, 144, 64, 128] 70,160
│ └─_DownsampleBlock: 2-4 [16, 144, 64, 128] [16, 128, 32, 64] 166,016
│ └─_ConvBlock: 2-5 [16, 128, 32, 64] [16, 208, 32, 64] 116,880
│ └─_DownsampleBlock: 2-6 [16, 208, 32, 64] [16, 128, 16, 32] 239,744
│ └─_ConvBlock: 2-7 [16, 128, 16, 32] [16, 208, 16, 32] 116,880
│ └─_DownsampleBlock: 2-8 [16, 208, 16, 32] [16, 128, 8, 16] 239,744
│ └─_ConvBlock: 2-9 [16, 128, 8, 16] [16, 208, 8, 16] 116,880
│ └─_DownsampleBlock: 2-10 [16, 208, 8, 16] [16, 256, 4, 8] 479,488
│ └─_ConvBlock: 2-11 [16, 256, 4, 8] [16, 336, 4, 8] 210,320
│ └─_DownsampleBlock: 2-12 [16, 336, 4, 8] [16, 256, 2, 4] 774,400
│ └─_ConvBlock: 2-13 [16, 256, 2, 4] [16, 336, 2, 4] 210,320
│ └─_DownsampleBlock: 2-14 [16, 336, 2, 4] [16, 512, 1, 2] 1,548,800
Both methods are equivalent; the change in the print-out is just an artifact of how torchinfo crawls the model.
torchinfo traces the model's forward pass and records every module involved. If the same module appears more than once, it is labeled (recursive). For nn.ModuleList objects, using items from the same ModuleList at different points of forward gets flagged as recursive simply because the ModuleList container itself shows up more than once in different places. Here's a simple example:
Example 1:
import torch.nn as nn
from torchinfo import summary

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.ModuleList([nn.Linear(8, 8) for i in range(2)])
        self.l2 = nn.Linear(8, 8)

    def forward(self, x):
        x = self.l1[0](x)
        x = self.l1[1](x)
        x = self.l2(x)
        return x

m = MyModel()
summary(m, (1, 8), depth=5)
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
MyModel [1, 8] --
├─ModuleList: 1-1 -- --
│ └─Linear: 2-1 [1, 8] 72
│ └─Linear: 2-2 [1, 8] 72
├─Linear: 1-2 [1, 8] 72
==========================================================================================
Total params: 216
Trainable params: 216
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
Example 2:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.ModuleList([nn.Linear(8, 8) for i in range(2)])
        self.l2 = nn.Linear(8, 8)

    def forward(self, x):
        x = self.l1[0](x)
        x = self.l2(x)
        x = self.l1[1](x)
        return x

m = MyModel()
summary(m, (1, 8), depth=5)
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
MyModel [1, 8] --
├─ModuleList: 1-3 -- (recursive)
│ └─Linear: 2-1 [1, 8] 72
├─Linear: 1-2 [1, 8] 72
├─ModuleList: 1-3 -- (recursive)
│ └─Linear: 2-2 [1, 8] 72
==========================================================================================
Total params: 216
Trainable params: 216
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
In the first example, we use all the layers in the ModuleList in order and get no recursive flag. In the second, we use the layers of the ModuleList at different points of the forward pass and get the recursive flag on the ModuleList object itself. This is just an artifact of how torchinfo crawls the model.
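If you want extra reassurance that training is unaffected, you can check that every registered parameter receives a gradient after a backward pass. A rough sanity check (illustrative, assuming your MyNet's forward returns the output tensor and every block participates in it):

import torch

model = MyNet(in_channels=4)
x = torch.randn(2, 4, 128, 256)
model(x).sum().backward()

# Every parameter in both zipped ModuleLists should now have a gradient,
# i.e. zipping the two lists did not break the computational graph.
for name, p in model.named_parameters():
    assert p.grad is not None, f"no gradient for {name}"
print("all registered parameters received gradients")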
As a purely stylistic note, there's nothing wrong with zipping ModuleLists, but if you know each _ConvBlock will always be paired 1:1 with a _DownsampleBlock, you might consider putting them into a combined module:
class CombinedBlock(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv_block = _ConvBlock(...)
        self.down_block = _DownsampleBlock(...)

    def forward(self, x):
        x = self.conv_block(x)
        skip = x
        x = self.down_block(x)
        return x, skip
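With that, the encoder needs only a single ModuleList and the forward loop stays flat. A sketch of how it could be wired up (the CombinedBlock constructor arguments are assumed; adapt them to however your two blocks are actually built):

def __init__(self, in_channels):
    super().__init__()
    self.encoder_blocks = nn.ModuleList()
    for out_channels in _FILTERS:
        self.encoder_blocks.append(CombinedBlock(in_channels, out_channels))
        in_channels = out_channels

def forward(self, x):
    skip_connections = []
    for block in self.encoder_blocks:
        x, skip = block(x)
        skip_connections.append(skip)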