Training an image classification model in PyTorch

Madhur Zanwar · Eumentis · Jan 15, 2024 · 5 min read
Image Classification in Action

This article is the first in a series of four on building an image classification model in PyTorch and porting it to mobile devices. In this article, we will:

  • describe our objective in training a classifier, and
  • present Python code to train this classifier using PyTorch.

Our use case: why we need an image classifier

Our end goal is to detect an object in the image.

In the agricultural sector, where livestock are raised to provide labor and products such as milk, eggs, and leather, machine learning can enhance productivity and optimize operations. Both during training and in real-world use, the object detection model needs good-quality images in which the object to be recognized is prominent in the frame (close to the camera) and not occluded. As the application gains popularity and is used across different places and environments, capturing good-quality images may not always be possible. Passing a poor-quality image, or one in which the object to be detected is not clearly visible, can produce random or bad results and undermine the model's predictive power. This is where image classification can be employed: it filters out unwanted images before they reach the object detection model for inference (a small sketch of this gating logic follows the list below). Cases in which we would not want an image to be passed to the model include:

  1. The image is blurry.
  2. The animal is too far away.
  3. The primary animal is occluded.
  4. The image shows a different animal than the one the model was built for.
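Here is a minimal sketch of that gating idea. The helper names below (classify, detect, GOOD_IMAGE_CLASS) are hypothetical placeholders, not part of the original pipeline; the classifier trained in this article would play the role of classify.

from typing import Any, Callable, Optional

GOOD_IMAGE_CLASS = 0  # hypothetical index of the "usable image" class

def detect_with_gate(image: Any,
                     classify: Callable[[Any], int],
                     detect: Callable[[Any], Any]) -> Optional[Any]:
    # run object detection only if the classifier accepts the image
    if classify(image) != GOOD_IMAGE_CLASS:
        return None  # rejected: blurry, occluded, too far, or wrong animal
    return detect(image)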

To start the process, we first build an image classifier using PyTorch on a web device.

Training an image classification model in PyTorch on a web device

First, we define the path to our dataset and the pre-processing steps to carry out on it. Pre-processing the dataset helps increase accuracy and prevents overfitting of the model.

# imports (standard PyTorch / torchvision setup)
import os
import time

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import efficientnet_b0
from tqdm import tqdm

# train on the GPU if one is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# path to dataset (placeholder -- replace with your own path)
IMAGE_DATASET_PATH = "path/to/your/dataset"
# training dataset path
TRAIN_DATASET_PATH = os.path.join(IMAGE_DATASET_PATH, "train")
# validation dataset path
VAL_DATASET_PATH = os.path.join(IMAGE_DATASET_PATH, "validation")
# directory where model checkpoints will be saved (placeholder)
TRAINING_ARTIFACTS_PATH = "training_artifacts/"

# Defining our data augmentation pipeline
# feature extraction batch size
FEATURE_EXTRACTION_BATCH_SIZE = 256
# finetune batch size
FINETUNE_BATCH_SIZE = 64
# prediction batch size
PRED_BATCH_SIZE = 4
# number of epochs
EPOCHS = 500
LR = 0.0001
LR_FINETUNE = 0.0005
# ImageNet statistics used to normalize the tensor
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# resize the image to IMAGE_SIZE
IMAGE_SIZE = 224
resize = transforms.RandomResizedCrop(IMAGE_SIZE)
hFlip = transforms.RandomHorizontalFlip(p=0.25)
normalize = transforms.Normalize(mean=MEAN, std=STD)
# apply the transformations (Normalize must come after ToTensor)
trainTransform = transforms.Compose([resize, hFlip, transforms.ToTensor(), normalize])
# the validation set gets a deterministic resize instead of a random crop
valTransform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor(), normalize])
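Before training, it can help to sanity-check the augmentation pipeline by running a single image through it and confirming the output shape. This is a quick sketch, assuming Pillow is installed; the image path is a placeholder.

from PIL import Image

# load any one training image (placeholder path) and apply the pipeline
sample = Image.open("path/to/any/training/image.jpg").convert("RGB")
tensor = trainTransform(sample)
# expect a float tensor of shape [3, IMAGE_SIZE, IMAGE_SIZE]
print(tensor.shape, tensor.dtype)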

We load the dataset and apply the transformations with the help of PyTorch's DataLoader. If training feels slow, feel free to experiment with the num_workers argument, which sets the number of processes that generate batches in parallel and can speed things up considerably (see the timing sketch after the code below).

def get_dataloader(rootDir, transforms, batchSize, shuffle):
    # create a dataset from the directory structure (one sub-folder per class)
    ds = datasets.ImageFolder(root=rootDir, transform=transforms)
    # wrap the dataset in a DataLoader by passing the required arguments
    loader = DataLoader(ds, batch_size=batchSize, shuffle=shuffle, num_workers=os.cpu_count(),
                        pin_memory=True if DEVICE == "cuda" else False)
    # return a tuple of the dataset and the data loader
    return (ds, loader)

# calling the function to load the training dataset
(trainDS, trainLoader) = get_dataloader(TRAIN_DATASET_PATH, transforms=trainTransform,
                                        batchSize=FEATURE_EXTRACTION_BATCH_SIZE, shuffle=True)
# calling the function to load the validation dataset
(valDS, valLoader) = get_dataloader(VAL_DATASET_PATH, transforms=valTransform,
                                    batchSize=FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)
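If data loading turns out to be the bottleneck, one rough way to tune num_workers is to time a few batches at each setting. This is a quick benchmarking sketch, not part of the original training script.

# time a few batches for different worker counts to find a sweet spot
for workers in (0, 2, 4, os.cpu_count()):
    loader = DataLoader(trainDS, batch_size=FEATURE_EXTRACTION_BATCH_SIZE,
                        shuffle=True, num_workers=workers)
    start = time.time()
    for i, _ in enumerate(loader):
        if i == 4:  # five batches give a rough estimate
            break
    print(f"num_workers={workers}: {time.time() - start:.2f}s for 5 batches")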

Having loaded our dataset and applied the desired transformations, we now define the EfficientNet-B0 base model. We do not want the layers of efficientnet_b0 to be retrained, so we freeze the base model's layers and train only a custom classification head on top.

# defining the efficientnet_b0 model with pretrained ImageNet weights
model = efficientnet_b0(pretrained=True)
# setting its parameters to non-trainable (by default they are trainable)
for param in model.parameters():
    param.requires_grad = False
# replace the classification head of the base efficientnet_b0 model;
# CrossEntropyLoss applies softmax internally, so the head outputs raw logits
NUM_CLASSES = 2  # placeholder -- set to the number of classes in your dataset
model.classifier[1] = nn.Linear(in_features=1280, out_features=NUM_CLASSES)
# transfer the model to the GPU if available
model = model.to(DEVICE)

# initialize the loss function and optimizer (notice that we are only
# providing the parameters of the classification head to our optimizer)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.classifier[1].parameters(), lr=LR)
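Before training, it is worth verifying that the freeze actually worked and only the new head will be updated. A small check (not in the original article):

# count trainable vs. total parameters to confirm only the head is trainable
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable:,} of {total:,}")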

Now that we have the model structure, let’s outline the steps for training the model.

# calculate steps per epoch for the training and validation sets
trainSteps = len(trainDS) // FEATURE_EXTRACTION_BATCH_SIZE
valSteps = len(valDS) // FEATURE_EXTRACTION_BATCH_SIZE
# history dictionary to track losses and accuracies across epochs
H = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
best_accuracy = 0
for e in tqdm(range(EPOCHS)):
    epoch_startTime = time.time()
    # set the model in training mode
    model.train()
    # initialize the total training and validation loss
    totalTrainLoss = 0
    totalValLoss = 0
    # initialize the number of correct predictions in the training and validation steps
    trainCorrect = 0
    valCorrect = 0
    # loop over the training set
    for (i, (x, y)) in enumerate(trainLoader):
        # send the input to the device
        (x, y) = (x.to(DEVICE), y.to(DEVICE))
        # perform a forward pass and calculate the training loss
        pred = model(x)
        loss = lossFunc(pred, y)
        # calculate the gradients
        loss.backward()
        # update the parameters every other batch (a simple two-step
        # gradient accumulation) and zero out the accumulated gradients
        if i % 2 == 0:
            opt.step()
            opt.zero_grad()
        # add the loss to the total training loss so far and
        # count the number of correct predictions
        totalTrainLoss += loss
        trainCorrect += (pred.argmax(1) == y).type(torch.float).sum().item()
    # switch off autograd for evaluation
    with torch.no_grad():
        # set the model in evaluation mode
        model.eval()
        # loop over the validation set
        for (x, y) in valLoader:
            # send the input to the device
            (x, y) = (x.to(DEVICE), y.to(DEVICE))
            # make the predictions and calculate the validation loss
            pred = model(x)
            totalValLoss += lossFunc(pred, y)
            # count the number of correct predictions
            valCorrect += (pred.argmax(1) == y).type(torch.float).sum().item()

That completes the forward and backward passes for one epoch. Still inside the epoch loop, we compute the average training and validation losses and save the weights that achieve the best validation accuracy:

    # calculate the average training and validation loss
    avgTrainLoss = totalTrainLoss / trainSteps
    avgValLoss = totalValLoss / valSteps
    # calculate the training and validation accuracy
    trainCorrect = trainCorrect / len(trainDS)
    valCorrect = valCorrect / len(valDS)
    # update our training history
    H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
    H["train_acc"].append(trainCorrect)
    H["val_loss"].append(avgValLoss.cpu().detach().numpy())
    H["val_acc"].append(valCorrect)
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, EPOCHS))
    print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(avgTrainLoss, trainCorrect))
    print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(avgValLoss, valCorrect))
    epoch_endTime = time.time()
    print("Epoch finished in: {} seconds".format(epoch_endTime - epoch_startTime))
    ######### INTERMITTENT SAVING OF MODEL ###########
    # save the model weights every 10 epochs
    if (e + 1) % 10 == 0:
        torch.save(model.state_dict(), TRAINING_ARTIFACTS_PATH + f"model_weights_epoch_{e + 1}.pt")
    # if this epoch's validation accuracy is the best so far, save these weights too
    if valCorrect > best_accuracy:
        best_accuracy = valCorrect
        # save the best model weights
        torch.save(model.state_dict(), TRAINING_ARTIFACTS_PATH + f"model_weights_best_{e + 1}.pt")

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(endTime - startTime))
# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.show()
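To reuse the best checkpoint later, for evaluation or for the mobile conversion covered next in this series, the saved state dict can be loaded back into an identically constructed model. A sketch, assuming the same NUM_CLASSES as above; the checkpoint file name is hypothetical and depends on which epoch achieved the best validation accuracy.

# rebuild the same architecture and load the best saved weights
best_model = efficientnet_b0(pretrained=True)
best_model.classifier[1] = nn.Linear(in_features=1280, out_features=NUM_CLASSES)
state = torch.load(TRAINING_ARTIFACTS_PATH + "model_weights_best_42.pt",  # hypothetical file name
                   map_location=DEVICE)
best_model.load_state_dict(state)
best_model = best_model.to(DEVICE).eval()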

This completes the model training. Please proceed to the next article in the series to see how to convert the model into a mobile-optimized format.
