Creating custom Datasets and DataLoaders with PyTorch

Bivek Adhikari
3 min read · Aug 31, 2020

This post shows how to create custom image datasets and dataloaders in PyTorch. Datasets that come prepackaged with PyTorch can be loaded directly through the torchvision.datasets module. The following code downloads the MNIST training set and loads it.

import torchvision
from torchvision import transforms

mnist_dataset = torchvision.datasets.MNIST(root="./data", download=True, train=True, transform=transforms.ToTensor())

Similarly, for image datasets whose files are organized into one sub-folder per class label, the ImageFolder class in torchvision.datasets can be used.

dataset = torchvision.datasets.ImageFolder(root="path/to/dataset", transform=transforms.ToTensor())

However, if your image files and labels are structured differently, you can create your own dataset by subclassing the Dataset class in torch.utils.data and overriding its methods. For this tutorial, we are going to use this facial landmarks dataset for a simple male/female image classification task, to demonstrate how to create a custom dataset when the data does not follow a standard layout. After downloading and extracting the zip file, we can delete the "face_landmarks.csv" and "create_landmark_dataset.py" files inside the folder, since we will not need them.
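Before turning to the faces data, it may help to see the Dataset contract in isolation. The sketch below is a toy example (the class name SquaresDataset and its contents are hypothetical): any object that implements __len__ and __getitem__ on top of torch.utils.data.Dataset can be used wherever PyTorch expects a map-style dataset.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Toy dataset (hypothetical): sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Total number of samples in the dataset.
        return self.size

    def __getitem__(self, index):
        # Return one (input, target) pair for the given position.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index


squares = SquaresDataset(5)
print(len(squares))  # 5
print(squares[3])    # (3, 9)
```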

I have created a custom CSV file containing the labels for the images inside the faces folder, which can be downloaded from here. The “faces.csv” file contains one label (“0” for female and “1” for male) for each image in our dataset, while the faces folder contains the images themselves. After downloading all the files, our folder structure is going to look like this.

project
|-- faces
|-- faces.csv
|-- demo.py

Now, we can go ahead and create our custom PyTorch dataset. We will create a Python file (“demo.py”) in the same folder and start by importing the required libraries.

import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import PIL.Image  # import the submodule explicitly so that PIL.Image.open is available
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt

Now, we will inherit from the Dataset class from torch.utils.data and override the __init__, __len__ and __getitem__ methods to make it fit our dataset.

In the constructor, we store the project root folder, the image directory and the list of image files, and we read the CSV labels file. We extract the second column from the labels file using iloc[:, 1] because that column contains the labels for our data.
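As a quick illustration of what iloc[:, 1] selects, here is a small sketch on an in-memory CSV. The column names and file names are assumptions made up for the example; they are not taken from the real faces.csv.

```python
import io

import pandas as pd

# Hypothetical CSV contents; the header and file names are illustrative only.
csv_text = "image_name,label\nperson_a.jpg,0\nperson_b.jpg,1\nperson_c.jpg,1\n"

frame = pd.read_csv(io.StringIO(csv_text))

# iloc selects by position: all rows (:) of the column at position 1,
# i.e. the second column, whatever its header is called.
labels = frame.iloc[:, 1]

print(labels.tolist())  # [0, 1, 1]
print(len(labels))      # 3
```

Because the selection is positional, it keeps working even if the label column is renamed, but it silently breaks if the column order changes.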

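class FacesDataset(Dataset):

    def __init__(self, root, image_dir, csv_file, transform=None):
        self.root = root
        self.image_dir = image_dir
        # os.listdir returns names in arbitrary order, so sort them to get a
        # deterministic order. Assumption: the rows of the CSV follow the same
        # sorted file-name order, so that image i matches label i.
        self.image_files = sorted(os.listdir(image_dir))
        # The second column of the CSV holds the labels.
        self.data = pd.read_csv(csv_file).iloc[:, 1]
        self.transform = transform

    def __len__(self):
        # The number of samples equals the number of label rows.
        return len(self.data)

    def __getitem__(self, index):
        image_name = os.path.join(self.image_dir, self.image_files[index])
        image = PIL.Image.open(image_name)
        label = self.data[index]
        if self.transform:
            image = self.transform(image)
        return (image, label)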

The __len__ method returns the number of samples in our dataset, and the __getitem__ method returns the (image, label) pair at a given index. The image is opened as a PIL Image so that we can apply torchvision transformations to it.
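Pairing images and labels purely by position only works when the directory listing and the CSV rows happen to be in the same order. One alternative, sketched below, is to read the file names from the CSV as well. This assumes the first CSV column holds the image file names, which the post does not confirm for faces.csv; the class name CsvIndexedFaces and the synthetic files are hypothetical.

```python
import tempfile
from pathlib import Path

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class CsvIndexedFaces(Dataset):
    """Hypothetical variant: file names and labels both come from the CSV."""

    def __init__(self, image_dir, csv_file, transform=None):
        self.image_dir = Path(image_dir)
        frame = pd.read_csv(csv_file)
        # Assumption: column 0 holds file names, column 1 holds labels.
        self.image_names = frame.iloc[:, 0].tolist()
        self.labels = frame.iloc[:, 1].tolist()
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        image = Image.open(self.image_dir / self.image_names[index]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[index]


# Tiny synthetic data set to exercise the class.
demo_root = Path(tempfile.mkdtemp())
demo_images = demo_root / "faces"
demo_images.mkdir()
rows = [("b.png", 1), ("a.png", 0), ("c.png", 1)]  # deliberately not sorted
for width, (name, _) in enumerate(rows, start=10):
    # Give each image a distinct width so it can be told apart later.
    Image.new("RGB", (width, 8)).save(demo_images / name)
pd.DataFrame(rows, columns=["image_name", "label"]).to_csv(
    demo_root / "faces.csv", index=False
)

demo_dataset = CsvIndexedFaces(demo_images, demo_root / "faces.csv")
first_image, first_label = demo_dataset[0]
print(first_image.size, first_label)  # (10, 8) 1
```

With this design each row of the CSV fully describes one sample, so the result does not depend on how the operating system orders the files in the folder.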

The next step is to point to our project folder, the image directory and the CSV file. We also define the transformation pipeline that resizes and crops our images to 80x80 pixels and converts them to PyTorch tensors.

root = Path(os.getcwd())
image_dir = root/'faces'
csv_file = root/'faces.csv'
transform_img = transforms.Compose([
    transforms.Resize(80),
    transforms.CenterCrop(80),
    transforms.ToTensor()
])

Now, we can pass these values to our FacesDataset class and load our dataset. The dataset can then be split into a training set and a testing set by using random_split. The lengths passed to random_split must add up to the total number of samples in the dataset.

dset = FacesDataset(root, image_dir, csv_file, transform=transform_img)

train_dataset, test_dataset = torch.utils.data.random_split(dset, [50, 18])
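Hard-coded split sizes break as soon as the number of samples changes. A minimal sketch of a more flexible variant is shown below: the sizes are derived from the dataset length, and a seeded generator makes the split reproducible. The TensorDataset is a stand-in for the faces dataset, and the 75% ratio and seed value are arbitrary assumptions.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the faces dataset: 68 small fake images with fake labels.
full = TensorDataset(torch.zeros(68, 3, 8, 8), torch.zeros(68, dtype=torch.long))

# Derive the split sizes from the dataset length so they always add up.
train_size = int(0.75 * len(full))
test_size = len(full) - train_size

# A seeded generator makes the random split reproducible between runs.
generator = torch.Generator().manual_seed(0)
train_part, test_part = random_split(full, [train_size, test_size], generator=generator)

print(len(train_part), len(test_part))  # 51 17
```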

We can verify that the dataset is loaded properly by displaying a sample image, label pair from our dataset.

def show_image(image, label):
    print(f"Label: {label}")
    # Tensors are (channels, height, width); matplotlib expects (height, width, channels).
    plt.imshow(image.permute(1, 2, 0))
    plt.show()

show_image(*train_dataset[0])
Output:

Label: 1

[Figure: the sample face image displayed by plt.imshow]
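The last step is to wrap a dataset in a DataLoader, which takes care of batching and shuffling. The sketch below uses a random TensorDataset as a stand-in for the training split so that it runs on its own; the batch size of 16 is an arbitrary choice. With the real data you would pass train_dataset to DataLoader instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for train_dataset: 50 fake 3x80x80 images with random 0/1 labels.
stand_in = TensorDataset(torch.rand(50, 3, 80, 80), torch.randint(0, 2, (50,)))

# shuffle=True reorders the samples at the start of every epoch.
loader = DataLoader(stand_in, batch_size=16, shuffle=True)

batch_shapes = []
for images, labels in loader:
    # Each iteration yields one batch of stacked images and their labels.
    batch_shapes.append((tuple(images.shape), tuple(labels.shape)))

print(batch_shapes[0])    # ((16, 3, 80, 80), (16,))
print(len(batch_shapes))  # 4
```

The final batch holds only the 2 leftover samples; pass drop_last=True if every batch needs to have the same size.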
