Imagine a network like this:
```python
import torch
import torch.nn as nn


class CoolCNN(nn.Module):
    def __init__(self):
        super(CoolCNN, self).__init__()
        self.initial_conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.parallel_conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.secondary_conv = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fully_connected1 = nn.Linear(32 * 8 * 8, 128)
        self.output_layer = nn.Linear(128, 10)

    def forward(self, x):
        main_path = self.max_pool(torch.relu(self.initial_conv(x)))
        parallel_path = self.max_pool(torch.relu(self.parallel_conv(x)))
        x = (main_path + parallel_path) / 2
        x = self.max_pool(torch.relu(self.secondary_conv(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = torch.relu(self.fully_connected1(x))
        x = self.output_layer(x)
        return x
```
It has this tree-like structure:
```
CoolCNN
└── Forward Pass
    ├── parallel_path
    │   ├── parallel_conv (Conv2d)
    │   ├── ReLU Activation
    │   └── max_pool (MaxPool2d)
    │
    ├── main_path
    │   ├── initial_conv (Conv2d)
    │   ├── ReLU Activation
    │   └── max_pool (MaxPool2d)
    │
    ├── Average main_path and parallel_path
    │
    ├── secondary_conv (Conv2d)
    ├── ReLU Activation
    ├── max_pool (MaxPool2d)
    │
    ├── Flatten the Tensor
    │
    ├── fully_connected1 (Linear)
    ├── ReLU Activation
    │
    └── output_layer (Linear)
```
As you can see, this tree accurately represents main_path and parallel_path as parallel branches.
I wanted to know if it's possible to extract this tree-like structure from a forward pass through the network alone, without relying on the backward pass at all. There are libraries like torchviz that build the graph from the one recorded for the backward pass, and I don't want to use that approach.
The only approach that I could find was forward hooks. That is in fact how torchsummary prints the model's summary.
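To make the limitation concrete, here is a minimal sketch (not torchsummary's actual code) that registers a forward hook on every leaf module of CoolCNN and records the order in which the hooks fire. The helper name `record_call_order` is hypothetical, and the 1x3x32x32 input is an assumption chosen to match the `32 * 8 * 8` flatten (two 2x2 poolings take 32x32 down to 8x8).

```python
import torch
import torch.nn as nn


class CoolCNN(nn.Module):
    # Same network as above, repeated so this snippet runs on its own
    def __init__(self):
        super().__init__()
        self.initial_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.parallel_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.secondary_conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fully_connected1 = nn.Linear(32 * 8 * 8, 128)
        self.output_layer = nn.Linear(128, 10)

    def forward(self, x):
        main_path = self.max_pool(torch.relu(self.initial_conv(x)))
        parallel_path = self.max_pool(torch.relu(self.parallel_conv(x)))
        x = (main_path + parallel_path) / 2
        x = self.max_pool(torch.relu(self.secondary_conv(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = torch.relu(self.fully_connected1(x))
        return self.output_layer(x)


def record_call_order(model, x):
    """Return the names of leaf modules in the order their forward hooks fire."""
    calls, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            calls.append(name)
        return hook

    for name, module in model.named_modules():
        if name and not list(module.children()):  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Remove the hooks so the model is left unchanged
        for handle in handles:
            handle.remove()
    return calls


print(record_call_order(CoolCNN(), torch.randn(1, 3, 32, 32)))
```

The result is a flat list (`initial_conv`, `max_pool`, `parallel_conv`, `max_pool`, ...). Nothing in it says that `parallel_conv` consumes `x` rather than the output of the first `max_pool`, and functional calls such as `torch.relu` and the averaging step do not show up at all.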
However, forward hooks only tell you the order in which the nodes are called. That order is just one topological sort of the underlying graph, and a graph cannot in general be reconstructed from a topological sort alone, because many different graphs share the same topological order. This means it's impossible to accurately derive the actual, unique tree representation from forward hooks alone (at least, that is my understanding).
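The topological-sort argument can be checked with a few lines of plain Python. This sketch uses made-up node names and a hypothetical helper `is_topological_order`; it shows one call order that is a valid topological sort of both the real branched graph and a purely sequential chain, so the order by itself cannot tell the two apart.

```python
def is_topological_order(order, edges):
    """True if, for every edge (u, v), u appears before v in order."""
    position = {node: i for i, node in enumerate(order)}
    return all(position[u] < position[v] for u, v in edges)


# A call order like the one forward hooks could report (illustrative names)
order = ["initial_conv", "parallel_conv", "average", "secondary_conv"]

# Graph 1: two parallel branches merged by the average (the real structure)
branched = [
    ("initial_conv", "average"),
    ("parallel_conv", "average"),
    ("average", "secondary_conv"),
]

# Graph 2: a purely sequential chain through the same nodes
chained = [
    ("initial_conv", "parallel_conv"),
    ("parallel_conv", "average"),
    ("average", "secondary_conv"),
]

print(is_topological_order(order, branched))  # True
print(is_topological_order(order, chained))   # True
```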
I have already seen this question where the OP seemed to be happy without an exact graph reconstruction, so it doesn't address my need.
So is there a way by which we can reconstruct the computation graph from a forward pass alone?
I didn't find an out-of-the-box way to do this, so I ended up writing a library that traces the forward pass and then visualizes the model's forward pass as a graph.
I did the tracing by (temporarily) wrapping all the standard PyTorch operations, as well as the forward methods of modules, so that every call did some additional bookkeeping to build the graph data structure.
In simplified form, it looked like this:
```python
import torch

# save original operations here to revert everything back after tracing the model
original_operations = {}


def wrap_operation(python_module, operation):
    func_name = operation.__name__
    # save a reference to the original operation
    original_operations[get_hashable_key(python_module, operation)] = operation

    def wrapped_operation(*args, **kwargs):
        # Do the necessary pre-call bookkeeping for every tensor input
        for arg in args:
            if isinstance(arg, torch.Tensor):
                do_pre_call_bookkeeping(python_module, operation, arg)
        # Call the original operation
        result = operation(*args, **kwargs)
        # Do the necessary post-call bookkeeping on the output
        do_post_call_bookkeeping(python_module, operation, result)
        return result

    # replace e.g. torch.relu with the wrapped version
    setattr(python_module, func_name, wrapped_operation)


for python_module, operation in LONG_LIST_OF_PYTORCH_OPS:
    # wrap all the inbuilt PyTorch operations
    wrap_operation(python_module, operation)
```
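The "temporarily wrap, then restore" part can be sketched on its own as a context manager. This is a minimal illustration, assuming the bookkeeping is just reporting the call name; `wrapped_torch_ops` is a hypothetical name, and only `torch.relu` and `torch.add` are wrapped here instead of the full list of operations.

```python
import contextlib

import torch


@contextlib.contextmanager
def wrapped_torch_ops(names, on_call):
    """Temporarily replace torch.<name> with a wrapper that reports each call."""
    originals = {name: getattr(torch, name) for name in names}

    def make_wrapper(name, original):
        def wrapper(*args, **kwargs):
            on_call(name)  # bookkeeping hook (here: just report the call)
            return original(*args, **kwargs)
        return wrapper

    try:
        for name, original in originals.items():
            setattr(torch, name, make_wrapper(name, original))
        yield
    finally:
        # Always restore the original functions, even if tracing fails
        for name, original in originals.items():
            setattr(torch, name, original)


calls = []
with wrapped_torch_ops(["relu", "add"], calls.append):
    y = torch.add(torch.relu(torch.ones(2)), torch.ones(2))
print(calls)  # ['relu', 'add']
```

Restoring in a `finally` block means a forward pass that raises does not leave PyTorch monkey-patched.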
The bookkeeping in the wrapped functions involved writing an extra field onto each output tensor that stored its "source operation", i.e. the operation that produced it. The downstream (consumer) operation then read this field to find out where the tensor came from, so that it could add an edge from the producer to itself in the graph's adjacency list.
Here is an illustration of it:
In simplified code:
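```python
from collections import defaultdict

# adjacency list: producer node -> list of consumer nodes
adj_list = defaultdict(list)


def do_post_call_bookkeeping(module, operation, tensor_output):
    # Set a "marker" on the output tensor so that whoever consumes it
    # knows which operation produced it
    tensor_output._source_node = get_hashable_key(module, operation)


def do_pre_call_bookkeeping(module, operation, tensor_input):
    # Tensors not produced by a traced operation (e.g. the model input) have no marker
    source_node = getattr(tensor_input, "_source_node", None)
    if source_node is None:
        return
    # Add a link from the producer of the tensor to this node (the consumer)
    adj_list[source_node].append(get_hashable_key(module, operation))
```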
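Putting the two halves together, the following sketch traces a hypothetical two-branch toy model (`TwoBranch`) and recovers its parallel branches from a single forward pass. It illustrates the idea under simplifying assumptions and is not the library's implementation: only `torch.relu`, `torch.add` and the `forward` methods of leaf modules are wrapped, each call gets its own node id (so repeated calls to the same operation stay distinct), and only single-tensor outputs are handled.

```python
import itertools
from collections import defaultdict

import torch
import torch.nn as nn


class TwoBranch(nn.Module):
    """Hypothetical toy model: two parallel branches merged by torch.add."""

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        return self.head(torch.add(torch.relu(self.a(x)), torch.relu(self.b(x))))


def trace(model, x):
    """Run one forward pass and return {producer_node: [consumer_nodes]}."""
    adj_list = defaultdict(list)
    counter = itertools.count()

    def traced(name, fn):
        def wrapper(*args, **kwargs):
            node = f"{name}#{next(counter)}"  # unique id for this particular call
            # pre-call: link the producer of every tensor input to this node
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    source = getattr(arg, "_source_node", None)
                    if source is not None:
                        adj_list[source].append(node)
            out = fn(*args, **kwargs)
            # post-call: mark the output so its consumers know who produced it
            out._source_node = node
            return out
        return wrapper

    original_ops = {name: getattr(torch, name) for name in ("relu", "add")}
    leaves = [(n, m) for n, m in model.named_modules() if not list(m.children())]
    try:
        for name, fn in original_ops.items():
            setattr(torch, name, traced(name, fn))
        for name, module in leaves:
            # an instance attribute shadows the class's forward method
            module.forward = traced(name, module.forward)
        with torch.no_grad():
            model(x)
    finally:
        # undo all the patching, even if the forward pass raised
        for name, fn in original_ops.items():
            setattr(torch, name, fn)
        for _, module in leaves:
            module.__dict__.pop("forward", None)
    return dict(adj_list)


print(trace(TwoBranch(), torch.randn(1, 4)))
```

In the printed adjacency list, both `relu#1` and `relu#3` point to `add#4`, which is exactly the branch-and-merge structure that a flat call order loses.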
This was the essence of how I implemented it. I also wrote a more detailed blog post about it.