A Practical Guide to Multi-Class Image Classification using MMPreTrain

Javad Rezaie (PhD)
10 min read · Aug 29, 2023

Background Image by Gerd Altmann from Pixabay

1. Introduction

Image classification is a fundamental task in computer vision that involves categorizing images into predefined classes. Deep learning techniques, especially convolutional neural networks (CNNs), have revolutionized image classification by achieving impressive accuracy on various datasets. In this tutorial, we will explore how to perform multi-class image classification using the MMPreTrain library.

2. MMPreTrain

MMPreTrain is an open-source pre-training toolbox based on PyTorch. It is a part of the OpenMMLab project. It provides multiple powerful pre-trained backbones and supports different pre-training strategies. MMPreTrain originated from the famous open-source projects MMClassification and MMSelfSup, and is developed with many exciting new features.

Pre-training is now an essential stage in visual recognition: strong pre-trained models improve a wide range of downstream vision tasks. Key features of MMPreTrain include:

  • Supports multiple pre-training strategies, including supervised pre-training, self-supervised pre-training, and semi-supervised pre-training.
  • Provides multiple powerful pre-trained backbones, including ResNet, ResNeXt, ViT, and Swin Transformer.
  • Easy to use and extend. MMPreTrain is built on top of MMCV, which provides a unified and comprehensive infrastructure for computer vision.
  • Well-documented. The documentation of MMPreTrain is clear and concise, making it easy for users to get started.

Specifically, for image classification tasks:

  • MMPreTrain supports multiple image classification datasets, including ImageNet, CIFAR-10, and MS COCO.
  • MMPreTrain can be used to fine-tune the pre-trained models on a specific dataset to improve the performance on that dataset.
  • MMPreTrain can be used to train new image classification models from scratch.
  • MMPreTrain can be used for transfer learning, that is, using a pre-trained model as the starting point for training a new model on a different task.

To use MMPreTrain for multi-class image classification, you can follow these steps:

  1. Create a conda environment and install the necessary packages.
  2. Download the dataset.
  3. Import the predefined configuration.
  4. Modify the configuration to specify the number of classes and other parameters.
  5. Configure the model, data loaders, validation evaluator, and optimization wrapper.
  6. Train/Fine-tune the model on the dataset.
  7. Evaluate the model on the validation set.

3. Hands-on Implementation

Source code: The original code can be downloaded from GitHub.

Trained model: The trained model is uploaded to Hugging Face, where it can be tested and/or downloaded.

3.1. Setup

Create a Conda environment and activate it:

conda create --name openmmlab python=3.10 -y
conda activate openmmlab

Install the necessary packages:

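# Install PyTorch with CUDA 11.7 support
conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia
# Reboot as in the original setup; afterwards, run `conda activate openmmlab` again
sudo reboot
# Clone MMPreTrain and install it from source in editable mode
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
pip install -U openmim && mim install -e .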
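
Once the installation finishes, a quick check that the expected packages can be found may save debugging time later. The sketch below only looks the packages up without importing them. The package list is an assumption about what this setup installs, so adjust it to your environment.

```python
import importlib.util

# Assumed package list for this tutorial's environment; adjust as needed.
REQUIRED = ("torch", "torchvision", "mmengine", "mmcv", "mmpretrain")


def missing_packages(names=REQUIRED):
    """Return the package names that cannot be located in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All expected packages were found.")
```

Run it inside the activated openmmlab environment; an empty result means every listed package is visible to that interpreter.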

3.2. Downloading Dataset

The Stanford Cars dataset can be downloaded from Kaggle.

3.3. Importing Predefined Configurations

Popular pre-trained architectures include CNNs such as ResNet and ResNeXt and transformers such as ViT and Swin Transformer. You could also choose a lightweight CNN such as MobileNetV2. In this tutorial, we’ll use the EfficientNetV2_b0 architecture for image classification. MMPreTrain provides predefined configurations for various models. Import the EfficientNetV2_b0 configuration:

_base_ = [
    'mmpretrain::efficientnet_v2/efficientnetv2-b0_8xb32_in1k.py'
]

3.4. Loading pretrained model (if necessary)

  • You can load a pretrained model from a URL.
  • You can load a pretrained model from a checkpoint file.
  • You can load a pretrained model from a local file.

Here we load the pretrained model from a URL:

load_from = "https://download.openmmlab.com/mmclassification/v0/efficientnetv2/efficientnetv2-b0_3rdparty_in1k_20221221-9ef6e736.pth"

3.5. Update the Model Configuration

  • Number of classes: The number of classes will depend on the dataset that you are using. For example, if you are using the Stanford Cars dataset, the number of classes would be 196.
num_classes = 196
data_preprocessor = dict(
    num_classes=num_classes)

model = dict(
    head=dict(
        num_classes=num_classes,
    ))
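
Only head.num_classes is written here because settings from the file listed in _base_ are inherited and merged key by key. The sketch below is a simplified illustration of that merging idea, not mmengine's actual implementation. The base dictionary is a hypothetical, heavily shortened stand-in for the inherited configuration. It also shows the effect of _delete_=True, which the optimizer section uses to discard inherited keys.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`.

    Simplified illustration only: nested dicts are merged key by key, and a
    dict carrying `_delete_=True` replaces the inherited dict entirely.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict):
            value = dict(value)  # shallow copy so the caller's dict is untouched
            delete = value.pop("_delete_", False)
            if not delete and isinstance(merged.get(key), dict):
                merged[key] = merge_config(merged[key], value)
            else:
                # Either nothing to merge with, or the inherited dict is dropped.
                merged[key] = merge_config({}, value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical, heavily shortened stand-in for the inherited base config.
base = {
    "model": {
        "backbone": {"type": "EfficientNetV2", "arch": "b0"},
        "head": {"type": "LinearClsHead", "num_classes": 1000},
    },
    "optim_wrapper": {"optimizer": {"type": "SGD", "lr": 0.1, "momentum": 0.9}},
}
override = {
    "model": {"head": {"num_classes": 196}},
    "optim_wrapper": {"_delete_": True, "optimizer": {"type": "AdamW", "lr": 5e-4}},
}

merged = merge_config(base, override)
print(merged["model"]["head"])     # head keeps its type, only num_classes changes
print(merged["optim_wrapper"])     # inherited SGD settings are gone
```

In this sketch the head keeps its inherited type while num_classes changes, and the optimizer wrapper contains only the new keys.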

3.6. Update the Optimizer Configuration

  • Learning rate: You can adjust the learning rate up or down depending on how well the model is learning.
  • Optimizer: The AdamW optimizer is a good choice for image classification tasks. You could also use the SGD optimizer, but you may need to adjust the learning rate more frequently.
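  • Scheduler: A step schedule (StepLR) is a common choice for image classification tasks. The configuration below instead uses a linear warm-up (LinearLR) followed by cosine annealing (CosineAnnealingLR). A plateau-based scheduler such as ReduceLROnPlateau, which lowers the learning rate when the validation metric stops improving, is another option.
  • Mixed precision: Mixed precision can be enabled if you have a GPU that supports it (here via AmpOptimWrapper). It typically speeds up training and reduces memory use with little or no loss in accuracy.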
warmup_epochs = 10
base_lr = 5e-4

optim_wrapper = dict(
    _delete_=True,  # discard the optimizer settings inherited from the base config
    type='AmpOptimWrapper',  # automatic mixed precision
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.001),
    # no weight decay for ViT-style tokens; EfficientNetV2 has no such
    # parameters, so these keys have no effect here
    paramwise_cfg=dict(custom_keys={
        '.cls_token': dict(decay_mult=0.0),
        '.pos_embed': dict(decay_mult=0.0)
    }),
)

param_scheduler = [
    # warm up learning rate scheduler
    dict(
        type='LinearLR',
        start_factor=1e-4,
        by_epoch=True,
        end=warmup_epochs,
        # update by iter
        convert_to_iter_based=True),
    # main learning rate scheduler
    dict(
        type='CosineAnnealingLR',
        eta_min=1e-5,
        by_epoch=True,
        begin=warmup_epochs)
]
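
To get a feel for the schedule this produces, the following sketch evaluates an idealised closed form of a linear warm-up followed by cosine annealing. It is an approximation for intuition only and does not reproduce mmengine's exact per-iteration stepping. It also assumes a total of 30 epochs, the value set in the training configuration.

```python
import math

BASE_LR = 5e-4        # base_lr from the config
START_FACTOR = 1e-4   # LinearLR start_factor
ETA_MIN = 1e-5        # CosineAnnealingLR eta_min
WARMUP_EPOCHS = 10    # warmup_epochs from the config
MAX_EPOCHS = 30       # assumed total number of epochs


def approx_lr(epoch):
    """Idealised learning rate at a (possibly fractional) epoch."""
    if epoch < WARMUP_EPOCHS:
        # Linear ramp from START_FACTOR * BASE_LR up to BASE_LR.
        progress = epoch / WARMUP_EPOCHS
        factor = START_FACTOR + (1.0 - START_FACTOR) * progress
        return BASE_LR * factor
    # Cosine decay from BASE_LR down to ETA_MIN over the remaining epochs.
    progress = min((epoch - WARMUP_EPOCHS) / (MAX_EPOCHS - WARMUP_EPOCHS), 1.0)
    return ETA_MIN + 0.5 * (BASE_LR - ETA_MIN) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 5, 10, 20, 30):
    print(f"epoch {epoch:2d}: lr = {approx_lr(epoch):.3e}")
```

The learning rate starts near zero, peaks at base_lr when the warm-up ends, and then decays smoothly towards eta_min.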

3.7. Update the Data Loaders Configuration

  • Batch size: A good starting point for the batch size is 32. You can adjust the batch size up or down depending on the amount of memory that you have available.
  • Image size: A good starting point for the image size is 224×224. You can adjust the image size up or down depending on the computational resources that you have available.
  • Augmentation: Some popular augmentation techniques include random cropping, random flipping, and color jittering. You can use a combination of these techniques to improve the performance of the model.
# Data set can be downloaded from https://www.kaggle.com/datasets/jutrera/stanford-car-dataset-by-classes-folder
# Change the data_root path in efficientnetv2_b0_config.py to the location where you downloaded the data:

data_root = "/path/to/Datasets/Stanford_Cars_by_class_folder/car_data/car_data/"
train_image_folder = "train"
val_image_folder = "test"
IMAGENET_CATEGORIES = ["AM General Hummer SUV 2000", "Acura RL Sedan 2012", "Acura TL Sedan 2012", "Acura TL Type-S 2008", "Acura TSX Sedan 2012", "Acura Integra Type R 2001", "Acura ZDX Hatchback 2012", "Aston Martin V8 Vantage Convertible 2012", "Aston Martin V8 Vantage Coupe 2012", "Aston Martin Virage Convertible 2012", "Aston Martin Virage Coupe 2012", "Audi RS 4 Convertible 2008", "Audi A5 Coupe 2012", "Audi TTS Coupe 2012", "Audi R8 Coupe 2012", "Audi V8 Sedan 1994", "Audi 100 Sedan 1994", "Audi 100 Wagon 1994", "Audi TT Hatchback 2011", "Audi S6 Sedan 2011", "Audi S5 Convertible 2012", "Audi S5 Coupe 2012", "Audi S4 Sedan 2012", "Audi S4 Sedan 2007", "Audi TT RS Coupe 2012", "BMW ActiveHybrid 5 Sedan 2012", "BMW 1 Series Convertible 2012", "BMW 1 Series Coupe 2012", "BMW 3 Series Sedan 2012", "BMW 3 Series Wagon 2012", "BMW 6 Series Convertible 2007", "BMW X5 SUV 2007", "BMW X6 SUV 2012", "BMW M3 Coupe 2012", "BMW M5 Sedan 2010", "BMW M6 Convertible 2010", "BMW X3 SUV 2012", "BMW Z4 Convertible 2012",
"Bentley Continental Supersports Conv. Convertible 2012", "Bentley Arnage Sedan 2009", "Bentley Mulsanne Sedan 2011", "Bentley Continental GT Coupe 2012", "Bentley Continental GT Coupe 2007", "Bentley Continental Flying Spur Sedan 2007", "Bugatti Veyron 16.4 Convertible 2009", "Bugatti Veyron 16.4 Coupe 2009", "Buick Regal GS 2012", "Buick Rainier SUV 2007", "Buick Verano Sedan 2012", "Buick Enclave SUV 2012", "Cadillac CTS-V Sedan 2012", "Cadillac SRX SUV 2012", "Cadillac Escalade EXT Crew Cab 2007", "Chevrolet Silverado 1500 Hybrid Crew Cab 2012", "Chevrolet Corvette Convertible 2012", "Chevrolet Corvette ZR1 2012", "Chevrolet Corvette Ron Fellows Edition Z06 2007", "Chevrolet Traverse SUV 2012", "Chevrolet Camaro Convertible 2012", "Chevrolet HHR SS 2010", "Chevrolet Impala Sedan 2007", "Chevrolet Tahoe Hybrid SUV 2012", "Chevrolet Sonic Sedan 2012", "Chevrolet Express Cargo Van 2007", "Chevrolet Avalanche Crew Cab 2012", "Chevrolet Cobalt SS 2010", "Chevrolet Malibu Hybrid Sedan 2010", "Chevrolet TrailBlazer SS 2009", "Chevrolet Silverado 2500HD Regular Cab 2012", "Chevrolet Silverado 1500 Classic Extended Cab 2007", "Chevrolet Express Van 2007", "Chevrolet Monte Carlo Coupe 2007", "Chevrolet Malibu Sedan 2007", "Chevrolet Silverado 1500 Extended Cab 2012", "Chevrolet Silverado 1500 Regular Cab 2012", "Chrysler Aspen SUV 2009", "Chrysler Sebring Convertible 2010", "Chrysler Town and Country Minivan 2012", "Chrysler 300 SRT-8 2010", "Chrysler Crossfire Convertible 2008", "Chrysler PT Cruiser Convertible 2008", "Daewoo Nubira Wagon 2002", "Dodge Caliber Wagon 2012", "Dodge Caliber Wagon 2007", "Dodge Caravan Minivan 1997", "Dodge Ram Pickup 3500 Crew Cab 2010", "Dodge Ram Pickup 3500 Quad Cab 2009", "Dodge Sprinter Cargo Van 2009", "Dodge Journey SUV 2012", "Dodge Dakota Crew Cab 2010", "Dodge Dakota Club Cab 2007", "Dodge Magnum Wagon 2008", "Dodge Challenger SRT8 2011", "Dodge Durango SUV 2012", "Dodge Durango SUV 2007", "Dodge Charger Sedan 2012", "Dodge Charger SRT-8 2009",
"Eagle Talon Hatchback 1998", "FIAT 500 Abarth 2012", "FIAT 500 Convertible 2012", "Ferrari FF Coupe 2012", "Ferrari California Convertible 2012", "Ferrari 458 Italia Convertible 2012", "Ferrari 458 Italia Coupe 2012", "Fisker Karma Sedan 2012", "Ford F-450 Super Duty Crew Cab 2012", "Ford Mustang Convertible 2007", "Ford Freestar Minivan 2007", "Ford Expedition EL SUV 2009", "Ford Edge SUV 2012", "Ford Ranger SuperCab 2011", "Ford GT Coupe 2006", "Ford F-150 Regular Cab 2012", "Ford F-150 Regular Cab 2007", "Ford Focus Sedan 2007", "Ford E-Series Wagon Van 2012", "Ford Fiesta Sedan 2012", "GMC Terrain SUV 2012", "GMC Savana Van 2012", "GMC Yukon Hybrid SUV 2012", "GMC Acadia SUV 2012", "GMC Canyon Extended Cab 2012", "Geo Metro Convertible 1993", "HUMMER H3T Crew Cab 2010", "HUMMER H2 SUT Crew Cab 2009", "Honda Odyssey Minivan 2012", "Honda Odyssey Minivan 2007", "Honda Accord Coupe 2012", "Honda Accord Sedan 2012", "Hyundai Veloster Hatchback 2012", "Hyundai Santa Fe SUV 2012", "Hyundai Tucson SUV 2012", "Hyundai Veracruz SUV 2012", "Hyundai Sonata Hybrid Sedan 2012", "Hyundai Elantra Sedan 2007", "Hyundai Accent Sedan 2012", "Hyundai Genesis Sedan 2012", "Hyundai Sonata Sedan 2012", "Hyundai Elantra Touring Hatchback 2012", "Hyundai Azera Sedan 2012", "Infiniti G Coupe IPL 2012", "Infiniti QX56 SUV 2011", "Isuzu Ascender SUV 2008", "Jaguar XK XKR 2012", "Jeep Patriot SUV 2012", "Jeep Wrangler SUV 2012", "Jeep Liberty SUV 2012", "Jeep Grand Cherokee SUV 2012", "Jeep Compass SUV 2012", "Lamborghini Reventon Coupe 2008", "Lamborghini Aventador Coupe 2012", "Lamborghini Gallardo LP 570-4 Superleggera 2012", "Lamborghini Diablo Coupe 2001", "Land Rover Range Rover SUV 2012", "Land Rover LR2 SUV 2012", "Lincoln Town Car Sedan 2011", "MINI Cooper Roadster Convertible 2012", "Maybach Landaulet Convertible 2012", "Mazda Tribute SUV 2011", "McLaren MP4-12C Coupe 2012", "Mercedes-Benz 300-Class Convertible 1993", "Mercedes-Benz C-Class Sedan 2012", "Mercedes-Benz SL-Class Coupe 2009",
"Mercedes-Benz E-Class Sedan 2012", "Mercedes-Benz S-Class Sedan 2012", "Mercedes-Benz Sprinter Van 2012", "Mitsubishi Lancer Sedan 2012", "Nissan Leaf Hatchback 2012", "Nissan NV Passenger Van 2012", "Nissan Juke Hatchback 2012", "Nissan 240SX Coupe 1998", "Plymouth Neon Coupe 1999", "Porsche Panamera Sedan 2012", "Ram C/V Cargo Van Minivan 2012", "Rolls-Royce Phantom Drophead Coupe Convertible 2012", "Rolls-Royce Ghost Sedan 2012", "Rolls-Royce Phantom Sedan 2012", "Scion xD Hatchback 2012", "Spyker C8 Convertible 2009", "Spyker C8 Coupe 2009", "Suzuki Aerio Sedan 2007", "Suzuki Kizashi Sedan 2012", "Suzuki SX4 Hatchback 2012", "Suzuki SX4 Sedan 2012", "Tesla Model S Sedan 2012", "Toyota Sequoia SUV 2012", "Toyota Camry Sedan 2012", "Toyota Corolla Sedan 2012", "Toyota 4Runner SUV 2012", "Volkswagen Golf Hatchback 2012", "Volkswagen Golf Hatchback 1991", "Volkswagen Beetle Hatchback 2012", "Volvo C30 Hatchback 2012", "Volvo 240 Sedan 1993", "Volvo XC90 SUV 2007", "smart fortwo Convertible 2012"]
METAINFO = {'classes': IMAGENET_CATEGORIES}

train_dataloader = dict(
    dataset=dict(
        metainfo=METAINFO,
        data_root=data_root,
        data_prefix=train_image_folder
    )
)

val_dataloader = dict(
    dataset=dict(
        metainfo=METAINFO,
        data_root=data_root,
        data_prefix=val_image_folder
    )
)
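
The data loader definitions above only change the dataset location and class names; batch size and augmentation are inherited from the base configuration. If you want to set them explicitly, a fragment along the following lines could be used. This is a sketch under assumptions: the transform names are assumed to be available in your installed MMPreTrain version, the numeric values are illustrative rather than tuned, and train_dataloader_overrides is a hypothetical name used only to show which keys would be folded into the existing train_dataloader definition.

```python
# Hypothetical additions to the configuration file; transform names are
# assumed to exist in the installed MMPreTrain version.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # random cropping, resized to 224x224
    dict(type='RandomResizedCrop', scale=224),
    # random horizontal flipping
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    # color jittering (illustrative strengths)
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='PackInputs'),
]

# Keys to fold into the existing train_dataloader dict (name is illustrative).
train_dataloader_overrides = dict(
    batch_size=32,  # per-GPU batch size
    dataset=dict(pipeline=train_pipeline),
)
```

Because a pipeline is a list, it replaces the inherited pipeline as a whole instead of being merged step by step, so the loading and packing steps have to be listed again.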

Remember: Change the data_root path in efficientnetv2_b0_config.py to the location where you downloaded the data (the train and test folders are located under data_root).
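
Before training, it can help to confirm that the folder layout matches what the configuration expects. The sketch below assumes one sub-folder per class under each of the train and test folders, as in the by-classes-folder version of the dataset. The function names and the list of image suffixes are illustrative.

```python
from pathlib import Path

# Assumed set of image file extensions; extend if your copy uses others.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def summarize_split(split_dir):
    """Map each class folder name to its image count, in sorted folder order."""
    split_dir = Path(split_dir)
    summary = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        summary[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return summary


def check_dataset(data_root, splits=("train", "test"), expected_classes=196):
    """Return a list of problems found; an empty list means the layout looks fine."""
    problems = []
    for split in splits:
        split_dir = Path(data_root) / split
        if not split_dir.is_dir():
            problems.append(f"missing folder: {split_dir}")
            continue
        summary = summarize_split(split_dir)
        if len(summary) != expected_classes:
            problems.append(
                f"{split}: found {len(summary)} class folders, expected {expected_classes}")
        empty = [name for name, count in summary.items() if count == 0]
        if empty:
            problems.append(f"{split}: {len(empty)} class folders contain no images")
    return problems
```

Calling check_dataset(data_root) with the path from the configuration should return an empty list. Comparing the sorted folder names from summarize_split with the class list in the configuration is also a cheap way to confirm that label names line up with folder order.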

3.8. Update the Validation Evaluator Configuration

  • Evaluator: The accuracy metric is a good choice for evaluating the performance of an image classification model. You could also use the loss, the precision, or the recall metrics.
  • Interval: A good starting point for the interval is 1 epoch. You can adjust the interval up or down depending on how often you want to evaluate the model.
val_evaluator = dict(type='Accuracy', topk=(1, 5))
test_evaluator = val_evaluator
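
With topk=(1, 5), the evaluator reports both top-1 and top-5 accuracy: a prediction counts as correct for top-k if the true class is among the k highest-scoring classes. The toy example below illustrates the idea in plain Python with made-up scores. It is an illustration of the metric, not MMPreTrain's implementation, and reporting the result as a percentage is a choice made for this sketch.

```python
def topk_accuracy(scores, labels, k=1):
    """Percentage of samples whose true label is among the k highest scores."""
    hits = 0
    for sample_scores, label in zip(scores, labels):
        # Class indices ordered from highest to lowest score.
        ranked = sorted(range(len(sample_scores)),
                        key=lambda i: sample_scores[i], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(labels)


# Toy example with 3 samples and 4 classes (made-up scores).
scores = [
    [0.1, 0.6, 0.2, 0.1],   # true class 1 has the highest score
    [0.5, 0.1, 0.3, 0.1],   # true class 2 has the second-highest score
    [0.4, 0.3, 0.2, 0.1],   # true class 3 has the lowest score
]
labels = [1, 2, 3]

top1 = topk_accuracy(scores, labels, k=1)
top2 = topk_accuracy(scores, labels, k=2)
print(f"top-1: {top1:.2f}%  top-2: {top2:.2f}%")
```

Here only the first sample is a top-1 hit, while the first two are top-2 hits, which is why top-k accuracy can never decrease as k grows.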

3.9. Update the Training Configuration

  • Number of epochs: A good starting point for the number of epochs is 30. You can adjust the number of epochs up or down depending on how well the model is learning.
  • Early stopping: Early stopping can be enabled if you want to prevent the model from overfitting.

max_epochs = 30
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# local path for saving the models and logs
work_dir = "./out"

# configure default hooks
default_hooks = dict(
    # save a checkpoint every epoch and keep only the most recent one
    checkpoint=dict(type='CheckpointHook', max_keep_ckpts=1),
)
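
The configuration above does not enable early stopping. One way to add it, sketched here under the assumption that your installed mmengine version provides EarlyStoppingHook, is a custom hook that monitors top-1 validation accuracy. The patience and min_delta values are illustrative, not tuned.

```python
# Assumed addition to the configuration file: stop training when top-1
# validation accuracy has not improved for `patience` validation rounds.
custom_hooks = [
    dict(
        type='EarlyStoppingHook',
        monitor='accuracy/top1',  # metric name reported by the Accuracy evaluator
        rule='greater',           # higher accuracy is better
        patience=5,               # illustrative value
        min_delta=0.1,            # minimum change counted as an improvement
    ),
]
```

With val_interval=1, validation runs once per epoch, so the patience is effectively counted in epochs.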

3.10. Load the configuration and train the model using mmengine

Create a separate Python script (main_train_mmengine.py) that loads the main configuration and starts the training process.

from mmengine.config import Config
from mmengine.runner import Runner
import argparse

def main(args):
    # Load the configuration file and use the PyTorch distributed launcher
    config = Config.fromfile(args.config_path)
    config.launcher = "pytorch"
    # Build the runner from the configuration and start training
    runner = Runner.from_cfg(config)
    runner.train()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Get Config Path.')
    parser.add_argument('config_path', type=str, help='path to the config file')
    args = parser.parse_args()
    main(args)

3.11. Run training from terminal

Start training by passing the configuration (efficientnetv2_b0_config.py) to the training script (main_train_mmengine.py). Set --nproc_per_node to the number of GPUs on your machine; the command below uses three.

torchrun --nnodes 1 --nproc_per_node=3 main_train_mmengine.py efficientnetv2_b0_config.py
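
The number of processes also determines the effective batch size. OpenMMLab config names encode the reference setup as a tag of the form GPUs x b batch-per-GPU, so 8xb32 stands for 8 GPUs with 32 images each. The sketch below works through the arithmetic for this run, assuming the per-GPU batch size of 32 is inherited unchanged and the standard Stanford Cars training split of 8,144 images is used. The helper names are illustrative.

```python
import math
import re


def parse_batch_tag(config_name):
    """Extract (num_gpus, batch_per_gpu) from a name like '..._8xb32_in1k.py'."""
    match = re.search(r"(\d+)xb(\d+)", config_name)
    if match is None:
        raise ValueError(f"no '<gpus>xb<batch>' tag in {config_name!r}")
    return int(match.group(1)), int(match.group(2))


def iterations_per_epoch(num_images, num_gpus, batch_per_gpu):
    """Optimizer steps per epoch when each GPU processes batch_per_gpu images per step."""
    return math.ceil(num_images / (num_gpus * batch_per_gpu))


base_gpus, base_batch = parse_batch_tag("efficientnetv2-b0_8xb32_in1k.py")
reference_batch = base_gpus * base_batch   # batch size the base recipe was written for
actual_batch = 3 * base_batch              # three processes, same per-GPU batch size

# 8144 is the size of the standard Stanford Cars training split (assumed here).
steps = iterations_per_epoch(8144, num_gpus=3, batch_per_gpu=base_batch)
print(reference_batch, actual_batch, steps)
```

Under these assumptions the run uses an effective batch size of 96 instead of the reference 256, which is worth keeping in mind when choosing the learning rate.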

3.12. Log Analysis (Visualizing training/validation results)

python mmpretrain/tools/analysis_tools/analyze_logs.py plot_curve ./out/path/to/scalars.json --keys accuracy/top1 accuracy/top5 --legend top1 top5 --out accuracy.jpg --title EfficientNetV2_b0
Validation Accuracy.

3.13. Log Analysis (Plot hyper-parameter scheduler of the optimizer, learning rate)

python mmpretrain/tools/analysis_tools/analyze_logs.py plot_curve ./out/path/to/scalars.json --keys lr --legend lr --out lr.jpg --title EfficientNetV2_b0
Optimizer Learning Rate.
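
The same log file can also be inspected directly. The sketch below assumes that scalars.json contains one JSON object per line and that validation records carry an accuracy/top1 key, matching the key passed to the plotting command. The function names are illustrative.

```python
import json


def read_scalars(path):
    """Read a JSON-lines log file into a list of dicts, skipping blank lines."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def best_record(records, key="accuracy/top1"):
    """Return the record with the highest value for key, or None if the key never appears."""
    candidates = [r for r in records if key in r]
    return max(candidates, key=lambda r: r[key]) if candidates else None
```

For example, best_record(read_scalars("./out/path/to/scalars.json")) returns the validation record with the highest top-1 accuracy, which shows the best value and the step at which it was logged.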

3.14. Model Complexity Analysis (Get the FLOPs and params)

python mmpretrain/tools/analysis_tools/get_flops.py /path/to/configuration/efficientnetv2_b0_config.py
Model Complexity Analysis
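
Part of the parameter count can be checked by hand. Changing the number of classes from 1000 to 196 only affects the final fully connected layer, whose size is in_features × out_features plus one bias per output. The sketch below assumes the backbone produces 1280 features, which is an assumption about EfficientNetV2_b0 rather than something stated in this article.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


IN_FEATURES = 1280  # assumed feature width of the backbone output

imagenet_head = linear_params(IN_FEATURES, 1000)  # original 1000-class head
cars_head = linear_params(IN_FEATURES, 196)       # 196-class head used here

print(f"1000-class head: {imagenet_head:,} parameters")
print(f" 196-class head: {cars_head:,} parameters")
print(f"difference:      {imagenet_head - cars_head:,} parameters")
```

Under that assumption, the fine-tuned model has roughly one million fewer parameters than the ImageNet version, all of them in the classification head.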

For more details about image classification and its practical implementation, see Image Classification with Deep Neural Networks.

4. Conclusion

Image classification with deep neural networks has transformed how we interact with visual data. From its historical roots to the present breakthroughs, DNNs have reshaped industries and opened new possibilities. As technology advances, the responsible development and deployment of image classification systems will play a pivotal role in shaping our AI-driven future.

References

Source code: The original code can be downloaded from GitHub.

Trained model: The trained model is uploaded to Hugging Face, where it can be tested and/or downloaded.

Image Classification with Deep Neural Networks

OpenMMLab. MMPreTrain: OpenMMLab Pre-training Toolbox and Benchmark.

Continue Reading

A Practical Guide to Object Detection using MMDetection with Docker

Disclaimer

This project is intended for educational purposes only. Any use of this project for real-world applications should be done with caution and proper consultation with relevant experts.
