Pipeline for every PyTorch Image Classification Problem / Creating Dataset

→ Creating Efficient Datasets for PyTorch Image Classification Tasks / Pipeline for Image Data Processing

siromer
6 min read · Apr 1, 2024

For every image classification problem, there are key steps, and building a pipeline for these steps can save time.

  • In this article, I am going to create a pipeline for handling data and training a model in the PyTorch framework.
Fish Species

This PyTorch pipeline consists of two parts. In the first part, I will create the dataset, and in the second part, I will train the model and visualize the results in graphs (link to the second part).

As I mentioned above, every image classification problem shares some main steps, and I generally follow these 6:

  1. Creating Dataset
  2. Visualization of Example Images
  3. Visualization of Class Distribution
  4. Create functions for training the model (second part)
  5. Create the model (second part)
  6. Train the model (second part)

1. Creating Dataset

To train a deep learning model, you need data. You can create your own dataset by scraping the internet, or you can download an existing dataset from websites such as Kaggle. Kaggle hosts a huge number of datasets, so you can probably find one that fits your purpose.

After obtaining the data, you cannot use it in your models right away, because PyTorch expects the data in a specific format. In fact, every framework (TensorFlow, for example) expects data in its own format, and you need to convert your data accordingly. PyTorch provides highly effective utilities for preparing your data, and they are remarkably easy to use.

I downloaded a fish dataset from Kaggle, and I will use it throughout this article.

When you download a dataset from the internet, it is probably not ready to use (in my opinion, 90% of the time). The author may put all images in one folder, split them into subfolders by class, or store the labels in a .csv file. You need to adjust the data accordingly. PyTorch works well with the format below, especially if you are going to use torchvision.datasets.ImageFolder:

Dataset Folders (image2)
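For the case where the labels are stored in a .csv file, the reorganization can be as simple as copying each image into a folder named after its label. This is a minimal sketch, assuming a hypothetical CSV with a header row and columns named `filename` and `label`, with all images in a single folder; `csv_to_image_folder` is an illustrative name, and the column names should be adjusted to your dataset.

```python
import csv
import os
import shutil


def csv_to_image_folder(csv_path, image_dir, output_dir):
    """Copy images into output_dir/<label>/ based on a CSV of filename,label rows.

    Assumes (hypothetically) a header row with the columns 'filename' and 'label'.
    """
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            # One subfolder per class, named after the label
            class_dir = os.path.join(output_dir, row["label"])
            os.makedirs(class_dir, exist_ok=True)
            # Copy (not move) so the original download stays untouched
            shutil.copy2(os.path.join(image_dir, row["filename"]), class_dir)
```

Copying instead of moving keeps the original download intact, so the script can be rerun safely if the layout turns out to be wrong.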

I am going to use torchvision.datasets.ImageFolder, so I have adjusted my data to the format shown above. This step is a little off-topic, so you can check my GitHub repository for the details. I split the data into training and validation sets with a 70/30 ratio. After splitting, we can create the datasets in the format PyTorch expects.
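The author's own splitting script is in the repository. As a rough stand-in, here is a minimal sketch, assuming the original images are already grouped into one subfolder per class: it copies a shuffled 70% of each class into `train/` and the rest into `validation/`, matching the folder names used in step 1.2. The function name `split_image_folder` and the fixed seed are illustrative choices, not taken from the original.

```python
import os
import random
import shutil


def split_image_folder(source_dir, output_dir, train_ratio=0.7, seed=42):
    """Split source_dir/<class>/<image> into output_dir/train and output_dir/validation.

    Each class is split separately, so both sets keep the same class proportions.
    """
    rng = random.Random(seed)  # fixed seed -> the same split every run
    for class_name in sorted(os.listdir(source_dir)):
        class_path = os.path.join(source_dir, class_name)
        if not os.path.isdir(class_path):
            continue  # skip stray files next to the class folders
        files = sorted(os.listdir(class_path))
        rng.shuffle(files)
        # round() so that e.g. 10 * 0.7 gives exactly 7 despite float error
        n_train = round(len(files) * train_ratio)
        subsets = (("train", files[:n_train]), ("validation", files[n_train:]))
        for subset, subset_files in subsets:
            target_dir = os.path.join(output_dir, subset, class_name)
            os.makedirs(target_dir, exist_ok=True)
            for file_name in subset_files:
                shutil.copy2(os.path.join(class_path, file_name), target_dir)
```

Splitting each class separately (a stratified split) avoids the situation where a rare class ends up almost entirely in one of the two sets.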

PyTorch provides two data primitives:

  • torch.utils.data.DataLoader
  • torch.utils.data.Dataset

→ torch.utils.data.Dataset : it allows you to use pre-loaded datasets as well as your own data. It stores the samples and their corresponding labels.

→ torch.utils.data.DataLoader : it wraps an iterable around the torch.utils.data.Dataset to enable easy access to the samples.
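ImageFolder, used below, is itself a subclass of torch.utils.data.Dataset. If your data does not fit a folder layout, you can write your own Dataset by implementing `__len__` and `__getitem__`, and wrap it in a DataLoader exactly like the built-in ones. This is a minimal, hypothetical sketch: `InMemoryImageDataset` is an illustrative name, and random tensors stand in for real images.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class InMemoryImageDataset(Dataset):
    """Hypothetical map-style Dataset holding image tensors and integer labels."""

    def __init__(self, images, labels, transform=None):
        self.images = images        # tensor of shape (N, C, H, W)
        self.labels = labels        # tensor of shape (N,)
        self.transform = transform  # optional callable applied to each image

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (sample, label) pair; DataLoader stacks these into batches
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


# 10 random 3x32x32 stand-in "images" with labels from 3 classes
dataset = InMemoryImageDataset(torch.rand(10, 3, 32, 32), torch.randint(0, 3, (10,)))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in loader:
    print(images.shape, labels.shape)
```

With 10 samples and `batch_size=4`, the loader yields batches of 4, 4 and 2 samples.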

So let's start creating the datasets for training PyTorch models.

— — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — —

  • 1.1 : Create transformation objects

Before using datasets.ImageFolder, I will create transformation objects, because they resize the images and convert them to tensors with values scaled to [0, 1].

from torchvision import transforms

"""
With transforms you can resize and normalize images, or create augmented datasets.
Here, I first resize the images and then turn them into torch.Tensor objects.
"""

train_transform = transforms.Compose([
    # Resize the image to 180x180 pixels
    transforms.Resize(size=(180, 180)),
    # Turn the image into a torch.Tensor: a PIL image with values 0-255 becomes a
    # float tensor of shape (C, H, W) with values in [0, 1]. This is a simple
    # rescaling, not a full mean/std normalization.
    transforms.ToTensor()
])

validation_transform = transforms.Compose([
    transforms.Resize(size=(180, 180)),
    transforms.ToTensor()
])

train_transform, validation_transform
  • 1.2 : Create Dataset From Folder (torchvision.datasets.ImageFolder)

ImageFolder is a generic data loader for images arranged in the layout shown in image2 (the second image): one subfolder per class, where each image's label is the name of the folder it is in.

# Use ImageFolder to create datasets
from torchvision import datasets

train_dir = "../Datasets/Fish_Dataset2/train"  # path to the train folder
validation_dir = "../Datasets/Fish_Dataset2/validation"  # path to the validation folder

train_data = datasets.ImageFolder(root=train_dir,
                                  transform=train_transform)

validation_data = datasets.ImageFolder(root=validation_dir,
                                       transform=validation_transform)

print(f"Train data:\n{train_data}\n\nValidation data:\n{validation_data}")
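ImageFolder derives the labels from the folder names. The simplified sketch below is a stand-in for illustration, not torchvision's actual code (which also filters files by extension, among other details); `image_folder_labels` is a hypothetical name. It shows the idea: the class subfolders are sorted alphabetically, and each class gets its position in that sorted list as its integer label. This is what `train_data.class_to_idx` contains, and it is why the train and validation folders should contain the same class subfolders: otherwise the same index could mean different classes in the two sets.

```python
import os


def image_folder_labels(root):
    """Simplified sketch of how ImageFolder assigns labels (not the library's code)."""
    # Class names are the subfolder names, sorted alphabetically
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    # Each class gets its position in the sorted list as its integer label
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for file_name in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, file_name), class_to_idx[name]))
    return class_to_idx, samples
```

For folders named `trout`, `bass` and `shrimp`, this gives `bass` the label 0, `shrimp` 1 and `trout` 2.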
  • 1.3 : Create Iterable Dataset for Training (torch.utils.data.DataLoader)

Iterable Dataset: the DataLoader turns the dataset into an iterable that yields the samples in batches during training. It manages batch creation, shuffling, and parallel data loading based on the parameters you specify. When training a model, it is not a good approach to update the parameters after every single image, because the gradient from one image is very noisy and processing images one at a time is slow. Instead, we group images into batches, and the parameters are updated once per batch.

from torch.utils.data import DataLoader

# train_data and validation_data were created with datasets.ImageFolder above

train_set = DataLoader(dataset=train_data,
                       batch_size=16,  # how many samples per batch?
                       num_workers=1,  # how many subprocesses to use for data loading? (higher = more)
                       shuffle=True)   # shuffle the data?

validation_set = DataLoader(dataset=validation_data,
                            batch_size=16,
                            num_workers=1,
                            shuffle=False)  # there is usually no need to shuffle validation data

train_set, validation_set
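Since the parameters are updated once per batch, the number of updates per epoch equals the number of batches, which is the number of samples divided by the batch size, rounded up. The small sketch below checks this with 50 stand-in integers instead of images (a plain Python list works as a map-style dataset) and the same `batch_size=16` as above; `drop_last=True` is a DataLoader option that discards the incomplete final batch.

```python
import math

from torch.utils.data import DataLoader

# 50 stand-in samples instead of images
samples = list(range(50))

loader = DataLoader(samples, batch_size=16, shuffle=True)
batch_sizes = [len(batch) for batch in loader]
print(len(loader), batch_sizes)  # 4 [16, 16, 16, 2]
print(len(loader) == math.ceil(len(samples) / 16))  # True

# drop_last=True discards the incomplete final batch of 2 samples
loader_drop = DataLoader(samples, batch_size=16, shuffle=True, drop_last=True)
print(len(loader_drop))  # 3
```

Dropping the last batch is occasionally useful when a layer such as batch normalization behaves badly on very small batches; otherwise the default keeps every sample.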

→ train_set and validation_set can be used for training

2. Visualization of Example Images

It is good practice to look at some example images from your dataset. This is especially helpful with an augmented dataset, because the augmentation sometimes does not produce the desired results; by looking at a few augmented examples, you can adjust the augmentation parameters. I am going to use Matplotlib for visualization.

import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image

# train_data was created above; invert class_to_idx to map label indices back to class names
label_dict = {idx: name for name, idx in train_data.class_to_idx.items()}

# Display up to 16 images in a 4x4 grid, titled with their class names
def show_images(images, labels):
    plt.figure(figsize=(12, 8))
    for i in range(min(len(images), 16)):
        plt.subplot(4, 4, i + 1)
        image = to_pil_image(images[i])  # convert the tensor back to a PIL image
        plt.imshow(image)
        plt.title(label_dict[labels[i].item()])  # numerical label -> class name
        plt.axis('off')
    plt.show()

# Get the first batch from the training DataLoader and display it
images, labels = next(iter(train_set))
show_images(images, labels)

3. Visualization of Class Distribution

Balancing a dataset is a crucial step, because it helps prevent the model from becoming biased towards one class. With an imbalanced dataset, the model tends to favor the majority classes and may perform poorly on the minority ones. Visualizing the class distribution is a quick way to see whether the dataset is balanced or imbalanced.

import os
import matplotlib.pyplot as plt

# train_dir and validation_dir were defined in step 1.2

def count_images_per_class(root_dir):
    """Return {class_name: number_of_files} for every class subfolder in root_dir."""
    class_counts = {}
    for class_folder in sorted(os.listdir(root_dir)):
        class_path = os.path.join(root_dir, class_folder)
        if os.path.isdir(class_path):
            class_counts[class_folder] = len(os.listdir(class_path))
    return class_counts

# calculate the class distributions of the train and validation sets
train_class_counts = count_images_per_class(train_dir)
validation_class_counts = count_images_per_class(validation_dir)

plt.figure(figsize=(15, 6))

# plot for the train set
plt.subplot(1, 2, 1)
plt.bar(list(train_class_counts.keys()), list(train_class_counts.values()))
plt.title('Training Set Distribution')
plt.xlabel('Classes')
plt.ylabel('Sample Numbers')
plt.xticks(rotation=45)

# plot for the validation set
plt.subplot(1, 2, 2)
plt.bar(list(validation_class_counts.keys()), list(validation_class_counts.values()))
plt.title('Validation Set Distribution')
plt.xlabel('Classes')
plt.ylabel('Sample Numbers')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()
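If the bar charts show a clear imbalance, one common remedy (not covered in this article) is to weight the loss by inverse class frequency. The sketch below computes such weights with the formula total / (num_classes × count), so a perfectly balanced dataset gets a weight of 1.0 for every class, plus a simple largest-to-smallest imbalance ratio. The function names, class names and counts are made up for illustration; with this pipeline you would pass `train_class_counts` instead.

```python
def balanced_class_weights(class_counts):
    """Inverse-frequency weights: weight_c = total / (num_classes * count_c).

    Rare classes get weights above 1.0 and frequent classes weights below 1.0.
    """
    total = sum(class_counts.values())
    num_classes = len(class_counts)
    return {name: total / (num_classes * count) for name, count in class_counts.items()}


def imbalance_ratio(class_counts):
    """Ratio between the largest and the smallest class (1.0 = perfectly balanced)."""
    return max(class_counts.values()) / min(class_counts.values())


# Hypothetical counts for illustration only
example_counts = {"class_a": 600, "class_b": 300, "class_c": 100}
print(imbalance_ratio(example_counts))  # 6.0
print(balanced_class_weights(example_counts))
```

To use the weights with `torch.nn.CrossEntropyLoss(weight=...)`, put them in class-index order, e.g. `torch.tensor([weights[name] for name in train_data.classes])`, since `train_data.classes` lists the class names in the same sorted order ImageFolder uses for the labels. Oversampling the rare classes with `torch.utils.data.WeightedRandomSampler` is an alternative.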
Visualization of Class Distribution
