Build an AI Classifier in only 30 lines of code — PyTorch Tutorial

Reza Kalantar
3 min read · Jan 1, 2023

PyTorch is a popular open-source machine learning library for Python that provides a high-level interface for working with dynamic computational graphs. It was primarily developed by Facebook’s artificial intelligence research group. Today, PyTorch is one of the most popular libraries for building deep learning models. Deep learning can learn intricate feature representations directly from images, which makes it well suited to image classification. In this tutorial, you will build and train a simple end-to-end deep learning classification network.

Original image from Mike van den Bos and modified by the author

First, let’s download the MNIST dataset and create train and test dataloaders. The dataset contains 60,000 training and 10,000 test examples of grayscale handwritten digits (0–9) with labels:

import torch
import torchvision
import torch.nn as nn

# Convert images to tensors, then normalize them (same transform for both splits)
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=100, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', train=False, download=True, transform=transform),
    batch_size=100, shuffle=False)
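To make the normalization step concrete, here is a minimal sketch of the per-pixel arithmetic. ToTensor scales pixel values to the range [0, 1], and Normalize then subtracts the mean and divides by the standard deviation. The values 0.1307 and 0.3081 are the ones used in the dataloaders above; the `normalize` function is a hypothetical stand-in written for illustration, not the torchvision implementation.

```python
MEAN, STD = 0.1307, 0.3081  # values passed to Normalize in the dataloaders

def normalize(pixel, mean=MEAN, std=STD):
    """Hypothetical helper: the same arithmetic Normalize applies to each pixel."""
    return (pixel - mean) / std

black = normalize(0.0)  # darkest pixel after ToTensor
white = normalize(1.0)  # brightest pixel after ToTensor
print(f'black -> {black:.4f}, white -> {white:.4f}')
```

After this step the inputs are roughly centered around zero, which generally helps gradient-based training.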

Create a simple convolutional neural network (CNN) with two convolutional layers followed by two fully-connected layers:

# Create the model
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)   # 1 input channel, 20 output channels, 5x5 kernel, stride 1
        self.conv2 = nn.Conv2d(20, 50, 5, 1)  # 20 input channels, 50 output channels
        self.fc1 = nn.Linear(4*4*50, 500)     # 50 feature maps of size 4x4, flattened
        self.fc2 = nn.Linear(500, 10)         # one output per digit class

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2, 2).view(-1, 4*4*50)  # flatten to (batch, 800)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)  # raw class scores (logits)
        return x

model = CNN()
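The 4*4*50 input size of fc1 follows from how each layer shrinks the 28x28 MNIST image. The sketch below traces that arithmetic with the standard output-size formula for unpadded convolution and pooling layers. `conv_out` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                    # MNIST images are 28x28
size = conv_out(size, 5, 1)  # conv1, 5x5 kernel, stride 1
size = conv_out(size, 2, 2)  # max pooling, 2x2 window, stride 2
size = conv_out(size, 5, 1)  # conv2, 5x5 kernel, stride 1
size = conv_out(size, 2, 2)  # max pooling, 2x2 window, stride 2

flat_features = size * size * 50  # conv2 produces 50 channels
print(size, flat_features)
```

If you change a kernel size or add a layer, rerunning this kind of calculation tells you the new input size for the first fully-connected layer.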

Define the cross-entropy loss criterion and optimizer:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
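nn.CrossEntropyLoss expects raw class scores (logits) and applies log-softmax internally, which is why the model's forward pass ends without a softmax. As a rough illustration of what the loss measures for a single example, here is a plain-Python sketch. The `cross_entropy` function is a hypothetical stand-in, and the logits are made-up values.

```python
import math

def cross_entropy(logits, label):
    """Negative log of the softmax probability assigned to the true class."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Equal scores for all 10 digits: the model is guessing, so the loss is ln(10)
uniform_loss = cross_entropy([0.0] * 10, label=3)

# A much higher score for the true class: the loss is close to zero
confident_loss = cross_entropy([0.0] * 3 + [10.0] + [0.0] * 6, label=3)

print(uniform_loss, confident_loss)
```

A loss near ln(10) ≈ 2.30 at the start of training is therefore a useful sanity check: it means the untrained network is predicting roughly uniformly over the ten digits.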

Finally, we can train the model using a standard PyTorch loop:

# Train the model for 5 epochs
for epoch in range(5):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    # Mean loss over all batches in this epoch
    print(f'[{epoch + 1}/5] loss: {running_loss / (i + 1)}')

Here is the output after training:

[1/5] loss: 0.2155205728672445
[2/5] loss: 0.05213625013556642
[3/5] loss: 0.035712676378704296
[4/5] loss: 0.026712203502247577
[5/5] loss: 0.02115860335577357

Great, now we have built and trained a CNN for classifying handwritten digits in just 30 lines of code!

(Optional) Let’s now evaluate the model on the test set and print the accuracy of each batch:

with torch.no_grad():  # gradients are not needed for evaluation
    for i, (inputs, labels) in enumerate(test_loader):
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)  # index of the highest score is the predicted digit
        correct = (predicted == labels).sum().item()
        print(f'Batch [{i+1}/{len(test_loader)}] accuracy: {correct / len(labels)}')

Here is the output for the first 10 test batches:

Batch [1/100] accuracy: 0.99
Batch [2/100] accuracy: 1.0
Batch [3/100] accuracy: 0.99
Batch [4/100] accuracy: 1.0
Batch [5/100] accuracy: 1.0
Batch [6/100] accuracy: 0.99
Batch [7/100] accuracy: 0.99
Batch [8/100] accuracy: 0.97
Batch [9/100] accuracy: 1.0
Batch [10/100] accuracy: 0.99
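Per-batch numbers are handy for a quick look, but a single accuracy over the whole test set is usually easier to compare between runs. One possible way to compute it is sketched below. The `evaluate` function is a hypothetical helper that is not part of the original 30 lines, and it assumes the loader yields (inputs, labels) pairs like the dataloaders above.

```python
import torch

def evaluate(model, loader):
    """Return the fraction of correctly classified examples across all batches."""
    model.eval()  # evaluation mode; makes no difference for this CNN, but is a safe habit
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)  # most likely class per example
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total
```

With the model and dataloader from this tutorial, it could be called as `evaluate(model, test_loader)`.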

The per-batch accuracy is already high (97–100% on the batches shown), but there are several ways to improve this model for more complex classification tasks. Here are three examples:

  • Make the CNN deeper with more trainable parameters
  • Use early stopping in the training loop, based on a validation set, to avoid overfitting
  • Visualize activation maps to identify where the CNN focuses when classifying each label
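As an illustration of the early-stopping idea, here is a minimal sketch that stops training once the validation loss has not improved for a given number of epochs. The `EarlyStopping` class, its `patience` and `min_delta` parameters, and the validation losses are all hypothetical and made up for this example. In practice, the losses would come from evaluating the model on a held-out validation split after each epoch.

```python
class EarlyStopping:
    """Signal a stop when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # improvement: remember it and reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1       # no improvement this epoch
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=2)
val_losses = [0.30, 0.20, 0.21, 0.22, 0.19]  # made-up validation losses, one per epoch
stopped_at = None
for epoch, val_loss in enumerate(val_losses, start=1):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)
```

A common refinement is to save the model weights whenever the best loss improves, so that the best checkpoint can be restored after stopping.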

If you find this tutorial helpful or would like to reach out, feel free to get in touch with me here, on GitHub, or on LinkedIn. Happy coding!
