Print PyTorch model.summary() like Keras

Reza Kalantar
3 min read · Dec 30, 2022


If you have worked with both PyTorch and Keras, you already know that these frameworks are eerily similar! However, each one has some useful functions that the other does not include natively. Printing a network summary is one of them. A summary gives you quick access to the model architecture, the output shape of each layer and the number of trainable parameters. Fortunately, you can now do this in PyTorch with just two lines of code!

Let’s create a simple PyTorch model:

import torch
import torch.nn as nn

# Define the model class
class MyModel(nn.Module):
    def __init__(self, num_classes):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32*7*7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = x.view(-1, 32*7*7)  # flatten the feature maps for the fully connected layers
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

# Create an instance of the model
model = MyModel(num_classes=10)

# Print the model's layers (PyTorch's built-in representation)
print(model)

Here is the output:

MyModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=1568, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
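Notice that fc1 reports in_features=1568, which is 32*7*7. Both convolutions use a 3x3 kernel with stride 1 and padding 1, so they leave the height and width unchanged. This means the model expects 7x7 inputs. Here is a quick pure-Python check of that arithmetic. It uses the standard Conv2d output-size formula with dilation 1. The helper names `conv2d_out` and `flattened_features` are only for illustration.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d layer along one dimension (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(h: int, w: int, channels: int = 32) -> int:
    """Features reaching fc1 after MyModel's two 3x3, stride-1, padding-1 convs."""
    for _ in range(2):  # conv1, then conv2
        h = conv2d_out(h, kernel=3, stride=1, padding=1)
        w = conv2d_out(w, kernel=3, stride=1, padding=1)
    return channels * h * w


print(flattened_features(7, 7))      # 1568, matches fc1's in_features
print(flattened_features(224, 224))  # 1605632, does not match fc1
```

Any other input size changes the number of flattened features. It is worth checking this before you choose input_size for a summary.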

However, the output of print(model) can get quite messy, especially as the model gets deeper and its architecture becomes more complex. It also shows neither the output shape of each layer nor the number of parameters. Thankfully, there is a library called torchsummary that prints a clean, Keras-like summary for a PyTorch model.
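As far as I can tell from its source, torchsummary works roughly like this. It registers a forward hook on each layer, runs a random batch through the model, and records each layer's output shape and parameter count. The sketch below hand-rolls that idea. `simple_summary` is a hypothetical helper and is not part of torchsummary. The small `nn.Sequential` network is just an example.

```python
import torch
import torch.nn as nn


def simple_summary(model: nn.Module, input_size: tuple) -> list:
    """Hypothetical helper: print a Keras-like table of the model's leaf layers.

    Runs one forward pass on a random batch of 2 and records each leaf
    module's output shape and parameter count via forward hooks.
    """
    rows, handles = [], []

    def hook(module, inputs, output):
        n_params = sum(p.numel() for p in module.parameters())
        rows.append((type(module).__name__, tuple(output.shape), n_params))

    # Hook only leaf modules (no children), e.g. Conv2d, ReLU, Linear
    for m in model.modules():
        if len(list(m.children())) == 0:
            handles.append(m.register_forward_hook(hook))

    with torch.no_grad():
        model(torch.rand(2, *input_size))

    for h in handles:
        h.remove()  # leave the model without hooks afterwards

    print(f"{'Layer (type)':<20}{'Output Shape':<25}{'Param #':>10}")
    for i, (name, shape, n) in enumerate(rows, start=1):
        print(f"{name + '-' + str(i):<20}{str(list(shape)):<25}{n:>10,}")
    total = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total:,}")
    return rows


# Example network (an assumption, not from the article)
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 7 * 7, 10),
)
rows = simple_summary(net, (3, 7, 7))
```

Only layers that are modules get hooked. Activations applied through nn.functional, as in MyModel, will therefore not show up as rows.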

First, you will need to install the library. You can do so by typing the following command in the terminal:

pip install torchsummary

Then, import the library and print the model summary:

import torchsummary

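# You need to define the input size (channels, height, width) so torchsummary can run a forward pass.
# fc1 expects 32*7*7 features and both convolutions keep the spatial size, so MyModel needs 7x7 inputs.
# torchsummary defaults to device="cuda"; pass "cpu" because the model was created on the CPU.
torchsummary.summary(model, input_size=(3, 7, 7), device="cpu")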

This time, the output is:

(Figure: A simple PyTorch model summary)
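If you are unsure which input_size a model accepts, you can run a dummy forward pass before calling torchsummary. The sketch below uses an nn.Sequential stand-in with the same layers as MyModel. This is an assumption made so the block is self-contained, and here nn.Flatten keeps the batch dimension. The sketch also counts trainable parameters with plain PyTorch, and that count should match the total reported in the summary.

```python
import torch
import torch.nn as nn

# Stand-in with the same layers as MyModel (num_classes=10)
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 7 * 7, 128), nn.ReLU(),
    nn.Linear(128, 10),
)


def output_shape(model: nn.Module, input_size: tuple, batch_size: int = 2) -> tuple:
    """Run one random batch through the model and return the output shape."""
    with torch.no_grad():
        return tuple(model(torch.rand(batch_size, *input_size)).shape)


print(output_shape(model, (3, 7, 7)))  # (2, 10)

# A 224x224 input yields 32*224*224 flattened features, which fc1 rejects
try:
    output_shape(model, (3, 224, 224))
except RuntimeError as err:
    print("input size mismatch:", err)

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {n_trainable:,}")  # 207,210
```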

We can also use torchsummary to explore other famous architectures, such as AlexNet:

from torchvision import models

# Create an instance of AlexNet from TorchVision (randomly initialized weights)
alexnet = models.AlexNet()
torchsummary.summary(alexnet, input_size=(3, 224, 224), device="cpu")

(Figure: AlexNet model summary)

Awesome! Now we have a clean output with a layout resembling the one in Keras. The summary also reports the number of trainable parameters, the input size and the estimated model size. These are all important considerations before training deep learning models.
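You can also work out the parameter total and the parameter size by hand. A Conv2d layer has k*k*c_in*c_out weights plus c_out biases. A Linear layer has n_in*n_out weights plus n_out biases. The sketch below applies these formulas to the layer sizes in TorchVision's AlexNet definition. Those sizes were transcribed by hand, so double-check them against your torchvision version. The sketch assumes 32-bit floats at 4 bytes per parameter. I believe torchsummary uses the same assumption for "Params size (MB)", but treat that as unverified.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Conv2d: k*k*c_in weights per filter, c_out filters, plus one bias each."""
    return k * k * c_in * c_out + c_out


def linear_params(n_in: int, n_out: int) -> int:
    """Linear: n_in*n_out weights plus n_out biases."""
    return n_in * n_out + n_out


# Layer sizes of TorchVision's AlexNet with 1000 classes (transcribed, assumption)
alexnet_layers = [
    conv_params(3, 64, 11),
    conv_params(64, 192, 5),
    conv_params(192, 384, 3),
    conv_params(384, 256, 3),
    conv_params(256, 256, 3),
    linear_params(256 * 6 * 6, 4096),
    linear_params(4096, 4096),
    linear_params(4096, 1000),
]

total = sum(alexnet_layers)
size_mb = total * 4 / 1024 ** 2  # assumes float32: 4 bytes per parameter
print(f"Total params: {total:,}")          # 61,100,840
print(f"Params size (MB): {size_mb:.2f}")  # 233.08
```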

All credit goes to Shubham Chandel for creating torchsummary.

If you enjoy my content, feel free to get in touch with me here, on GitHub or on LinkedIn. Happy coding!
