4 Visualization Tools for PyTorch

Visualizing deep neural networks with ease

Oliver Lövström
Internet of Technology
3 min read · Apr 5, 2024


Visualizing neural networks is essential for debugging, documentation, and more. Here are the top four visualization tools I use with PyTorch.

Photo by Steve Johnson on Unsplash

Today, we’ll work with a simple convolutional network that takes 48×48 grayscale images (one input channel) and predicts one of seven classes. Feel free to adapt the code to your needs. Here’s the model:

import torch
import torch.nn as nn


class SimpleCNN(nn.Module):
    """Small CNN for 1x48x48 grayscale inputs with 7 output classes."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 48x48 input -> 24x24 after pool1 -> 12x12 after pool2, with 32 channels
        self.fc1 = nn.Linear(in_features=32 * 12 * 12, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=7)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = x.view(-1, 32 * 12 * 12)  # flatten the conv features for the linear layers
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)  # raw logits, one per class
        return x
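The value 32 * 12 * 12 passed to fc1 is hard-coded for 48×48 inputs. As a minimal sketch, here is how that number follows from the usual output-size formula for Conv2d and MaxPool2d (dilation 1, no ceil mode); conv_out and simple_cnn_flat_features are hypothetical helpers written for this illustration, not part of PyTorch.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of nn.Conv2d / nn.MaxPool2d (dilation=1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def simple_cnn_flat_features(height: int, width: int) -> int:
    """Flattened feature count SimpleCNN's conv stack yields for an H x W input."""
    h, w = height, width
    for _ in range(2):  # two stages: conv (k=3, p=1) followed by max-pool (k=2, s=2)
        h, w = conv_out(h, 3, padding=1), conv_out(w, 3, padding=1)
        h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)
    return 32 * h * w  # conv2 produces 32 channels


print(simple_cnn_flat_features(48, 48))  # matches 32 * 12 * 12
print(simple_cnn_flat_features(64, 64))  # a 64x64 input would need a different fc1
```

If you change the input resolution, a helper like this tells you what in_features fc1 needs.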

Torchinfo

Torchinfo (formerly torch-summary) is a Python package that prints a layer-by-layer summary of a model, similar to model.summary() in TensorFlow/Keras:

  • Installation: pip install torchinfo

Code for printing the summary:

from torchinfo import summary

model = SimpleCNN()
summary(model, input_size=(1, 48, 48))

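Output (torchinfo’s documentation recommends including the batch size in input_size, e.g. input_size=(1, 1, 48, 48); because (1, 48, 48) omits it, PyTorch treats the input as a single unbatched image, so the convolutional layers’ shapes below start with the channel count):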

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleCNN                                [1, 7]                    --
├─Conv2d: 1-1                            [16, 48, 48]              160
├─MaxPool2d: 1-2                         [16, 24, 24]              --
├─Conv2d: 1-3                            [32, 24, 24]              4,640
├─MaxPool2d: 1-4                         [32, 12, 12]              --
├─Linear: 1-5                            [1, 256]                  1,179,904
├─Linear: 1-6                            [1, 7]                    1,799
==========================================================================================
Total params: 1,186,503
Trainable params: 1,186,503
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 4.87
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.44
Params size (MB): 4.75
Estimated Total Size (MB): 5.20
==========================================================================================
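The Param # column can be checked by hand: a Conv2d layer has out × in × k × k weights plus one bias per output channel, a Linear layer has in × out weights plus one bias per output feature, and pooling layers have none. A minimal sketch (the helper names are hypothetical), assuming float32 parameters at 4 bytes each and 1 MB = 10^6 bytes, which is consistent with the Params size line above:

```python
def conv2d_params(in_ch: int, out_ch: int, kernel: int) -> int:
    """Weights (out x in x k x k) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix (in x out) plus one bias per output feature."""
    return in_features * out_features + out_features


layers = {
    "conv1": conv2d_params(1, 16, 3),
    "conv2": conv2d_params(16, 32, 3),
    "fc1": linear_params(32 * 12 * 12, 256),
    "fc2": linear_params(256, 7),
}
total = sum(layers.values())
# float32 = 4 bytes per parameter; dividing by 10**6 bytes per MB
params_mb = total * 4 / 1e6

for name, count in layers.items():
    print(f"{name}: {count:,}")
print(f"total: {total:,} params, {params_mb:.2f} MB")
```

This reproduces the table’s 1,186,503 parameters and 4.75 MB, and makes it obvious that fc1 alone holds over 99% of the model’s parameters.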

PyTorchViz

PyTorchViz is a Python package that draws the autograd graph PyTorch records during a forward pass, rendered with Graphviz.

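  • Installation: pip install torchviz (rendering also requires the Graphviz system package, e.g. brew install graphviz or apt-get install graphviz)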

Code for generating graph:

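import torch
from torchviz import make_dot

model = SimpleCNN()
x = torch.randn(1, 1, 48, 48)  # dummy batch: one grayscale 48x48 image
output = model(x)
# params= labels the weight/bias nodes with their parameter names
graph = make_dot(output, params=dict(model.named_parameters()))
# Writes SimpleCNN.png; cleanup=True deletes the intermediate DOT source file
graph.render("SimpleCNN", format="png", cleanup=True)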
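Roughly speaking, make_dot builds its picture by walking the autograd graph that starts at output.grad_fn and follows each node’s next_functions back toward the parameters. The following is a simplified sketch of that idea, not torchviz’s actual code; autograd_node_names is a hypothetical helper, and a tiny stand-in model keeps the printed list short.

```python
from collections import deque

import torch
import torch.nn as nn


def autograd_node_names(output: torch.Tensor) -> list[str]:
    """Breadth-first walk of the autograd graph that starts at output.grad_fn."""
    names = []
    seen = set()
    queue = deque([output.grad_fn])
    while queue:
        fn = queue.popleft()
        if fn is None or fn in seen:
            continue  # None marks an input that doesn't require grad
        seen.add(fn)
        names.append(type(fn).__name__)  # e.g. an Addmm or Relu backward node
        # next_functions holds (node, input_index) pairs, one step closer to the inputs
        queue.extend(next_fn for next_fn, _ in fn.next_functions)
    return names


# Tiny stand-in model; SimpleCNN's output can be passed in the same way.
tiny = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
out = tiny(torch.randn(1, 4))
print(autograd_node_names(out))
```

The AccumulateGrad nodes in the result correspond to the weights and biases; they are the nodes that make_dot labels with parameter names when you pass params=.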

Output:

Image by Author

Netron

Another tool for plotting neural networks is Netron. Export your model to an ONNX file, then open it at netron.app.

  • Installation: pip install onnx

Code for exporting to ONNX:

import torch

model = SimpleCNN()
model.eval()  # export in inference mode
dummy_input = torch.randn(1, 1, 48, 48)  # example input used to trace the model
torch.onnx.export(model, (dummy_input,), "SimpleCNN.onnx")

Visualizing in Netron:

Image by Author

Matplotlib

If you want more control over the visualization, I recommend using Matplotlib.

Code for visualizing filters of a convolutional layer:

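import matplotlib.pyplot as plt

model = SimpleCNN()  # untrained here; load your trained weights first for meaningful filters
# conv1.weight has shape (16, 1, 3, 3): 16 filters, 1 input channel, 3x3 kernels
filters = model.conv1.weight.detach().numpy()
for i in range(filters.shape[0]):
    plt.imshow(filters[i, 0], cmap="gray")  # the single input channel of filter i
    plt.axis("off")
    plt.savefig(f"filter_{i}.png", bbox_inches="tight", pad_inches=0)
    plt.close()  # free the figure before drawing the next filter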
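Saving one file per filter works, but all 16 filters can also go into a single grid. A minimal sketch, assuming a NumPy array shaped like conv1.weight, i.e. (out_channels, in_channels, kH, kW); plot_filter_grid is a hypothetical helper, and random data stands in for real weights so the block runs on its own.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy as np


def plot_filter_grid(filters: np.ndarray, path: str, ncols: int = 4):
    """Save the first input channel of every filter as one grid; returns the figure."""
    n = filters.shape[0]
    nrows = (n + ncols - 1) // ncols  # ceiling division
    vmin, vmax = filters.min(), filters.max()  # one shared scale for all filters
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # also hides empty cells when n isn't a multiple of ncols
        if i < n:
            ax.imshow(filters[i, 0], cmap="gray", vmin=vmin, vmax=vmax)
    fig.savefig(path, bbox_inches="tight")
    return fig


if __name__ == "__main__":
    # Random stand-in for model.conv1.weight.detach().numpy(), shape (16, 1, 3, 3)
    demo_filters = np.random.default_rng(0).normal(size=(16, 1, 3, 3))
    fig = plot_filter_grid(demo_filters, "conv1_filters.png")
    plt.close(fig)
```

The shared vmin/vmax is a deliberate choice: by default, imshow stretches each image to its own min and max, which makes every filter look equally strong; a common scale lets you compare weight magnitudes across filters.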

Output:

Image by Author

Matplotlib isn’t limited to filters, though; it can handle almost any kind of visualization. Here is an example 3D visualization of the YOLOv8 backbone:

Image by Author

Further Reading

If you want to learn more about programming and, specifically, machine learning, see the following course:

Note: If you use my links to order, I’ll get a small kickback.
