
How can I get the inference compute graph of a PyTorch model?


I want to hand-write a framework that performs inference for a given neural network. The network is very complicated, so to make sure my implementation is correct, I need to know exactly how the inference process is carried out on the device.
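One way to check a hand-written implementation layer by layer (a technique not taken from the question itself) is to register PyTorch forward hooks on the leaf modules and record each module's output during an inference pass. Those tensors can then be compared against the intermediate results of the hand-written framework. The `nn.Sequential` model below is a hypothetical stand-in for the real network, and `make_hook`, `records`, and `reference` are illustrative names.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real network.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()

records = []    # (module name, module type, output shape), in call order
reference = {}  # module name -> output tensor, to compare with your own framework


def make_hook(name):
    def hook(module, inputs, output):
        # Runs right after module.forward() returns.
        records.append((name, type(module).__name__, tuple(output.shape)))
        reference[name] = output.detach().clone()
    return hook


# Hook only leaf modules (modules without children), i.e. the actual layers.
handles = [
    m.register_forward_hook(make_hook(name))
    for name, m in model.named_modules()
    if next(m.children(), None) is None
]

with torch.no_grad():
    final = model(torch.randn(1, 3, 32, 32))

for h in handles:
    h.remove()  # detach the hooks so later forward passes are unaffected

for name, kind, shape in records:
    print(name, kind, shape)
```

Because hooks fire when modules are actually called, `records` also reflects the execution order, which can differ from the declaration order in models with a custom `forward`. Functional calls inside `forward` (for example `torch.relu(x)`) are not modules, so they do not trigger a hook.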

I tried visualizing the network with torchviz, but the result appears to be the backpropagation (autograd) compute graph, which is very hard to understand.

[Image: torchviz output showing the backpropagation graph]
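torchviz's `make_dot` builds its picture by following the output tensor's `grad_fn` links, i.e. the autograd graph that backpropagation uses, which is why most of its node names contain "Backward". The sketch below walks the same links by hand for a small hypothetical model; `walk` is an illustrative helper, not part of torchviz.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))  # hypothetical model
y = model(torch.randn(1, 4))  # grad mode is on, so autograd records y.grad_fn


def walk(fn, depth, seen, lines):
    """Depth-first walk over autograd nodes, the same links torchviz follows."""
    if fn is None or fn in seen:
        return
    seen.add(fn)
    lines.append("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1, seen, lines)


lines = []
walk(y.grad_fn, 0, set(), lines)
print("\n".join(lines))
```

Running the forward pass under `torch.no_grad()` leaves `y.grad_fn` as `None`, so there is nothing for torchviz to draw: this graph is a by-product of autograd rather than a description of the inference computation.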

Then I converted the PyTorch model to ONNX format following the instructions for ONNX export, but when I visualized the result, the original layers of the model appeared to have been split into many very small operators.

[Image: ONNX graph with the layers split into small operators]

I just want a result like this:

[Image: Netron view of the model showing its original layers]

How can I get this?
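A text version of that layer-level view can also be produced with `torch.fx`, a PyTorch tool the question does not mention. `symbolic_trace` records the forward pass and, by default, keeps standard layers such as `nn.Conv2d` as single `call_module` nodes instead of breaking them into primitive operators. The `Net` class below is hypothetical, and a model whose `forward` contains data-dependent control flow may not trace without changes.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Net(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.bn(self.conv(x)).relu()
        x = x.mean(dim=(2, 3))
        return self.fc(x)


gm = symbolic_trace(Net().eval())

# One line per node: opcode (placeholder, call_module, call_method, ...), target, inputs.
for node in gm.graph.nodes:
    print(node.op, node.target, node.args, node.kwargs)

# The traced GraphModule is still runnable, so it can serve as a reference implementation.
out = gm(torch.randn(1, 3, 16, 16))
```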


Solution

  • Try saving the model with torch.save and opening the saved file with Netron. The last image you showed is a view from the Netron application.
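A minimal sketch of that workflow, using a hypothetical model: save it with `torch.save`, then open the file in Netron (the desktop app, or `pip install netron` followed by `netron model.pt`). Netron's PyTorch support has been described as experimental, so how much layer structure it recovers from a pickled model may vary; a traced TorchScript file is a second option that Netron can usually open as well.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)).eval()  # hypothetical model

# Pickle the whole module (not only its state_dict), so the file contains
# the module objects rather than just the weights.
torch.save(model, "model.pt")

# Alternative: a traced TorchScript file, which also records the forward computation.
traced = torch.jit.trace(model, torch.randn(1, 4))
traced.save("model_traced.pt")

# Then, from a shell:  netron model.pt
```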