python pytorch

How to extract the class name of a PyTorch network container


Assume I have a network created as follows:

p = torch.nn.Sequential(torch.nn.Linear(self.inputSize, self.outputSize))

I know that I can print the network with:

print(p)

and get:

Sequential(
  (0): Linear(in_features=22, out_features=3, bias=True)
)

I want to pretty-print the network, so I can use

for name, module in p.named_children():   
    print(f'{name:>10} {module}')

to print each layer's name followed by the layer itself. For the example network above I'd get:

         0 Linear(in_features=22, out_features=3, bias=True)

But how do I get the 'Sequential' part? It's the nn module container class, so is the only way to do this to dissect the class name returned by type() (<class 'torch.nn.modules.container.Sequential'>)?


Solution

  • There is no need to parse the string returned by type(). You can get the class of the network's top-level container from its .__class__ attribute. To get just the name, use .__class__.__name__.

    print(p.__class__.__name__)
    for name, module in p.named_children():   
        print(f'   {name:<2} {module}')
    

    Prints:

    Sequential
       0  Linear(in_features=22, out_features=3, bias=True)
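
If the network nests containers inside containers, the same attribute can be read at every level. The following is only a sketch under a few assumptions: the helper names class_name and describe are hypothetical (they are not part of PyTorch), the layer sizes are made up for illustration, and type(module).__name__ is used as an equivalent spelling of module.__class__.__name__.

```python
import torch


def class_name(module):
    # For ordinary objects, type(obj).__name__ gives the same result as
    # obj.__class__.__name__.
    return type(module).__name__


def describe(module, name="", depth=0):
    """Collect one line per module, indenting children under their container."""
    label = f"{name}: " if name else ""
    lines = ["  " * depth + label + class_name(module)]
    for child_name, child in module.named_children():
        lines.extend(describe(child, child_name, depth + 1))
    return lines


# Illustrative network: a Sequential that contains another Sequential.
net = torch.nn.Sequential(
    torch.nn.Linear(22, 3),
    torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(3, 1)),
)

lines = describe(net)
print("\n".join(lines))
```

This prints only class names; to keep each layer's full description (in_features and so on), the line could use str(module) for leaf modules instead.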