Image Classification with ResNet, ConvNeXt using PyTorch

Kaustav Mandal
exemplifyML.ai
7 min read · May 28, 2022

Synopsis: Image classification with ResNet, ConvNeXt along with data augmentation techniques on the Food 101 dataset

Photo by Brooke Lark on Unsplash

A quick walk-through of using CNN models for image classification and fine-tuning them for better accuracy.

Dataset used: Food 101

Libraries used: pytorch, torchvision, cudatoolkit, tensorboard

Preprocessing Data:

Number of categories: 101

Training set images per category: 750

Test set images per category: 250

For the purposes of this tutorial, we are going to break down the training images into a train set and a validation set in an 80:20 ratio.

Training set per category: 600

Validation set per category: 150
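The 80:20 split per category can be sketched as below. This is a minimal illustration, not the article's actual preprocessing script: the function name `split_per_category`, the fixed seed, and the file names are all assumptions made for the example.

```python
import random


def split_per_category(files_by_class, train_fraction=0.8, seed=0):
    """Split each category's file list into train and validation subsets."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train, val = {}, {}
    for label, files in files_by_class.items():
        shuffled = list(files)
        rng.shuffle(shuffled)
        cut = int(round(len(shuffled) * train_fraction))
        train[label] = shuffled[:cut]
        val[label] = shuffled[cut:]
    return train, val


# Hypothetical file names standing in for one Food 101 category (750 images).
example = {"apple_pie": [f"apple_pie/{i}.jpg" for i in range(750)]}
train_files, val_files = split_per_category(example)
print(len(train_files["apple_pie"]), len(val_files["apple_pie"]))  # 600 150
```

Splitting inside each category, rather than over the whole pool of images, keeps every class at exactly 600 training and 150 validation images.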

Loading Data:

We are setting up a data loader with a distributed sampler, so that training can leverage multiple GPUs.

Additional details on setting up pytorch for using multiple GPUs can be found here.
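As a rough sketch of how a distributed sampler divides a dataset between processes, the snippet below uses a toy `TensorDataset` in place of the image folder. `WORLD_SIZE` and `RANK` are assumed values that a real run would obtain from the launcher; passing them explicitly lets the example run without initializing a process group.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset standing in for the image folder: 10 samples, 3 features each.
dataset = TensorDataset(
    torch.arange(30, dtype=torch.float32).view(10, 3),
    torch.arange(10),
)

WORLD_SIZE = 4  # assumed number of GPU processes
RANK = 0        # index of this process

# Each rank receives its own slice of the (shuffled) dataset indices.
sampler = DistributedSampler(dataset, num_replicas=WORLD_SIZE, rank=RANK, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # gives a different shuffle each epoch
    for features, labels in loader:
        pass  # the forward/backward pass would go here

print(len(sampler))  # 3
```

Each rank sees ceil(N / num_replicas) samples, with a few indices repeated as padding when N is not divisible. Applied to the 25,250 Food 101 test images (101 x 250), four processes would give 6,313 samples per rank, which matches the "Test Sample Size" in the results if four GPUs were used; the GPU count is not stated, so that is an inference.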

Training:

We are using a pre-trained model in order to leverage weights that have already been trained on a large dataset.

To make the model more relevant to our food dataset, we will swap out the fully connected layer for a custom one and train it on our dataset.

Custom Fully Connected (FC) Layer Definition for ResNets:

import torch.nn as nn

# model: a pretrained torchvision ResNet (e.g. resnet34)
n_inputs = model.fc.in_features
n_outputs = 101  # number of Food 101 categories

sequential_layers = nn.Sequential(
    nn.Linear(n_inputs, 2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(2048, 2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Linear(2048, n_outputs),
    nn.LogSoftmax(dim=1)  # outputs log-probabilities
)
model.fc = sequential_layers
# make sure the new head is trainable
for param in model.fc.parameters():
    param.requires_grad = True

Model 1: ResNet34 (pretrained=True), training only the FC layer

Trainable Parameters:

+--------------------+------------+
| Modules            | Parameters |
+--------------------+------------+
| module.fc.0.weight |    1048576 |
| module.fc.0.bias   |       2048 |
| module.fc.1.weight |       2048 |
| module.fc.1.bias   |       2048 |
| module.fc.4.weight |    4194304 |
| module.fc.4.bias   |       2048 |
| module.fc.5.weight |       2048 |
| module.fc.5.bias   |       2048 |
| module.fc.7.weight |     206848 |
| module.fc.7.bias   |        101 |
+--------------------+------------+
Total Trainable Params: 5462117
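The total above can be reproduced by summing `numel()` over the parameters that require gradients. The helper names `count_trainable` and `build_head` below are illustrative, not taken from the article's code. The head mirrors the FC definition above, with 512 inputs as in ResNet34.

```python
import torch.nn as nn


def count_trainable(module):
    """Sum the element counts of all parameters that require gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def build_head(n_inputs, n_outputs=101):
    """Same layer stack as the custom FC head defined above."""
    return nn.Sequential(
        nn.Linear(n_inputs, 2048),
        nn.BatchNorm1d(2048),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(2048, 2048),
        nn.BatchNorm1d(2048),
        nn.ReLU(),
        nn.Linear(2048, n_outputs),
        nn.LogSoftmax(dim=1),
    )


print(count_trainable(build_head(512)))  # 5462117
```

The BatchNorm running statistics are buffers rather than parameters, which is why only the weight and bias of each BatchNorm layer appear in the table.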

Result:

exp_no:428 | Test Sample Size: 6313 | Rank: 0, Test Loss: 1.3087375105494234, Test Accuracy: 0.6513543481704419

An accuracy of 65% leaves room for improvement. We could look at the following options for increasing it:

  1. Increase number of parameters
  2. Increase size of the dataset
  3. Try another classification model with a different architecture

Effect of Increasing the number of parameters trained

Model 2: ResNet101 (pretrained=True), training only the FC layer

Trainable Parameters:

+--------------------+------------+
| Modules            | Parameters |
+--------------------+------------+
| module.fc.0.weight |    4194304 |
| (remaining rows same as above)  |
+--------------------+------------+
Total Trainable Params: 8607845

Result:

exp_no:426 | Test Sample Size: 6313 | Rank: 0, Test Loss: 1.034651347919744, Test Accuracy: 0.7191509583399335

Great, we improved our accuracy from 65% to 72% on the same dataset by moving from ResNet34 to ResNet101, which increases the number of parameters.

Testing a different architecture (ConvNeXt):

Model 3: ConvNeXt-B (pretrained=True), training only the classification layer

ConvNeXt is a convolutional model whose design follows that of a vision Transformer (Swin), without any attention-based modules.

Figure: Block designs of Swin Transformer, ResNet, and ConvNeXt (arXiv:2201.03545)

Trainable Parameters:

Classifier Layer definition for ConvNeXt:

import torch.nn as nn
from torchvision.models.convnext import LayerNorm2d

# Look up the input width of the final Linear layer (classifier index '2')
n_inputs = None
for name, child in model.named_children():
    if name == 'classifier':
        for sub_name, sub_child in child.named_children():
            if sub_name == '2':
                n_inputs = sub_child.in_features
n_outputs = 101

sequential_layers = nn.Sequential(
    LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True),
    nn.Flatten(start_dim=1, end_dim=-1),
    nn.Linear(n_inputs, 2048, bias=True),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(2048, 2048),
    nn.BatchNorm1d(2048),
    nn.ReLU(),
    nn.Linear(2048, n_outputs),
    nn.LogSoftmax(dim=1)
)
model.classifier = sequential_layers

Result:

exp_no:420 | Test Sample Size: 6313 | Rank: 0, Test Loss: 0.604434494471448, Test Accuracy: 0.8300332646919056

We improved our model accuracy from 72% to 83% by switching to ConvNeXt, a derivative model based on the original ResNet architecture.

Data Augmentation:

In this section, we will focus on data augmentation techniques.

torchvision provides different types of image transforms, which we can leverage for augmenting images during training.

We will primarily be using resizing, center-cropping, and normalizing transforms.

As the data is loaded via the distributed data loader, the images are transformed/augmented on the fly.

Minimal Transform

import torchvision.transforms as T

train_transform = T.Compose([
    T.Resize((256, 256)),
    T.CenterCrop((224, 224)),
    T.ToTensor(),
    # ImageNet channel means and standard deviations
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

Expanding upon the basic transforms with the auto-augmentation transforms (PyTorch Auto Augmentation Transforms):

import torchvision.transforms as T
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode

T.Compose([
    # apply the three augmentation policies in a random order
    T.RandomOrder((
        T.AutoAugment(policy=AutoAugmentPolicy.IMAGENET, interpolation=InterpolationMode.BILINEAR),
        T.RandAugment(),
        T.TrivialAugmentWide()
    )),
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

To increase our dataset size further, we can treat the resized/normalized images as one dataset and the augmented images as another, and use both together in our data loader.

The ConcatDataset class (torch.utils.data.ConcatDataset) provides functionality for fusing different datasets together and passing them to the distributed data loader as one big dataset.
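A minimal sketch of combining a plain view and an augmented view of the same samples follows. `TransformedView` is a hypothetical wrapper written for this example, and a horizontal flip stands in for the augmentation pipeline; in the article the two datasets would be image-folder datasets with different transforms.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset


class TransformedView(Dataset):
    """Hypothetical wrapper: the same samples seen through a given transform."""

    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, label = self.samples[idx]
        return self.transform(image), label


# Six toy "images" with alternating labels.
samples = [(torch.rand(3, 8, 8), i % 2) for i in range(6)]

plain = TransformedView(samples, transform=lambda x: x)
flipped = TransformedView(samples, transform=lambda x: torch.flip(x, dims=[-1]))

combined = ConcatDataset([plain, flipped])
print(len(combined))  # 12
```

Indices 0 to 5 of `combined` come from the plain view and indices 6 to 11 from the flipped view, so a sampler drawing from it sees both versions of every image.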

Additionally, we will be using the FiveCrop transform, which crops an image into its four corners and a central crop. Unlike the other transforms, it returns a tuple of five images, which shows up as an extra dimension on the tensor once the crops are stacked.

Figure: FiveCrop transform, shown without the batch dimension

There are two approaches to using the FiveCrop transform.

Approach 1:

Use a custom collate function in the data loader itself, where the collate function takes the five-crop output and stacks it along with the other augmented images.

This approach gives better speed and efficiency, at the cost of needing GPUs with more memory, since each image now becomes five (1:5). If memory becomes a bottleneck, we can reduce the batch size.

Build out the DataLoader with the transform function as illustrated below:

# Passed as the transform parameter of the ImageFolder dataset.
# FIS.stack_images is a custom function that stacks a list of image tensors into a
# 4-D tensor via torch.stack; it is a named function (rather than a lambda) to work
# around pickle restrictions.
transform=T.Compose([
    T.Resize((256, 256)),
    T.FiveCrop(100),
    T.Lambda(FIS.stack_images),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

Add custom Collate Function for Dataloader
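The article's collate function (`collate_4D_3D_tensor`) is not reproduced in this text, so the following is only a guess at its general shape. It assumes that each sample is either a single (C, H, W) image or a (n_crops, C, H, W) stack, and that all images in a batch share the same spatial size. The name `collate_mixed_crops` is hypothetical.

```python
import torch


def collate_mixed_crops(batch):
    """Merge (C, H, W) images and (n_crops, C, H, W) stacks into one 4-D batch.

    Assumes every image in the batch has the same C, H and W.
    """
    images, labels = [], []
    for image, label in batch:
        if image.dim() == 3:
            image = image.unsqueeze(0)  # treat a single image as a stack of one
        images.append(image)
        labels.extend([label] * image.shape[0])  # repeat the label once per crop
    return torch.cat(images, dim=0), torch.tensor(labels)


# One plain image (label 0) and one five-crop stack (label 1).
batch = [
    (torch.rand(3, 100, 100), 0),
    (torch.rand(5, 3, 100, 100), 1),
]
images, labels = collate_mixed_crops(batch)
print(images.shape, labels.tolist())  # torch.Size([6, 3, 100, 100]) [0, 1, 1, 1, 1, 1]
```

Folding the crops into the batch dimension keeps the model input 4-D, at the price of an effective batch size up to five times larger, which is the memory cost mentioned above.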

Assign the custom collate function to the dataloader

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    sampler=dist_train_samples,
    pin_memory=True,
    collate_fn=collate_4D_3D_tensor
)

Approach 2:

Instead of applying the five-crop transform on the fly with a custom collate function, we can add a preprocessing step that stores the output of the five_crop function as image files in the training dataset folder.

With this approach, we can use GPUs with less memory, or larger batch sizes, as images are not split into multiple crops at load time.

Additionally, the five-crop images are generated only once and can be reused across all iterations, as opposed to the previous approach, where the five-crop transforms are computed on the fly.

The drawback of this approach is longer model training time.

For using five_crop as a preprocessing step, we perform the following steps.

a. Use PIL library to load train images

list_of_PIL_fmt_images = []
list_of_PIL_fmt_images.append((Image.open(src_file_path, mode='r', formats=['JPEG'])).convert(pil_image_mode))

b. Convert them into tensors and apply five_crop

...
list_of_img_as_tensor = [T.ToTensor()(pilimg).to('cuda')
                         for pilimg in list_of_PIL_fmt_images]
...
# Stack up the images by adding an extra dim to the tensor
torch_stack_batch_per_proc = torch.stack(list_of_img_as_tensor)
...
# transform the tensors
augmented_five_crop_imgs = transform_fivecrop_seq(torch_stack_batch_per_proc)

c. Convert them to NumPy ndarrays

...
# move transformed augmented images to cpu
augmented_five_crop_imgs = augmented_five_crop_imgs.to('cpu', torch.uint8).numpy()
# for each image save it via the PIL library
Image.fromarray(ndarr)

d. Save the images in the corresponding class/label folder
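Step d can be sketched as follows. The function name `save_crops`, the file-naming scheme, and the temporary output directory are assumptions made for the example; the arrays are expected to be uint8 in height x width x channel layout, which is what `Image.fromarray` needs for RGB data.

```python
import os
import tempfile

import numpy as np
from PIL import Image


def save_crops(crops_hwc, out_dir, label, stem):
    """Write each uint8 (H, W, 3) crop as a JPEG inside the class/label folder."""
    class_dir = os.path.join(out_dir, label)
    os.makedirs(class_dir, exist_ok=True)
    paths = []
    for i, arr in enumerate(crops_hwc):
        path = os.path.join(class_dir, f"{stem}_crop{i}.jpg")
        Image.fromarray(arr).save(path, format="JPEG")
        paths.append(path)
    return paths


# Five random crops standing in for the five_crop output of one image.
rng = np.random.default_rng(0)
crops = rng.integers(0, 256, size=(5, 100, 100, 3), dtype=np.uint8)

out_dir = tempfile.mkdtemp()
saved = save_crops(crops, out_dir, label="apple_pie", stem="img0001")
print(len(saved))  # 5
```

Writing the crops into the same class folders as the originals means an image-folder dataset picks them up with the correct labels and no further changes.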

Note: In this approach, instead of the Compose function, we will be using the nn.Sequential module and adding custom classes for stacking and transforming the images.

Ideally, ImageStackerPlain and NumpySafeTransform would be one-liners; however, I was not able to find a way to use the Lambda transform along with the nn.Sequential class.

transform_fivecrop_seq = nn.Sequential(
    BatchImageTransform(),
    ImageStackerPlain(),
    NumpySafeTransform()
).cuda(rank)

Example definition of a BatchImageTransform class

Now that we have gone over the data augmentation details, here are the results of the training with the augmented data for a pre-trained ConvNeXt-B with a custom classifier layer for the food image classification problem.

Training Results:

Approach 1: Using collate function in data loader

exp_no:438 | Test Sample Size: 6313 | Rank: 0, Test Loss: 0.5250024251036762, Test Accuracy: 0.8509424996039918

Approach 2: Preprocessing FiveCrop and using 1:5 image files for input

exp_no:434 | Test Sample Size: 6313 | Rank: 0, Test Loss: 0.5627518318429455, Test Accuracy: 0.8458735941707588
Figure: Step-by-step validation set accuracy, run in tandem with training

Both FiveCrop augmentation approaches worked out, and we bumped our accuracy from 83% to roughly 85% using FiveCrop data augmentation along with other transforms.

Conclusion:

After trying different models, increasing the number of parameters, and applying various image augmentation techniques, we were able to increase the classification accuracy of food categories from 65% to 85%.

References:

  1. PyTorch Documentation
  2. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. (2015). Deep Residual Learning for Image Recognition
  3. Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, Kaiming He. (2016). Aggregated Residual Transformations for Deep Neural Networks
  4. Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie. (2022). A ConvNet for the 2020s
  5. Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. (2021). Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  6. Dan Hendrycks, Kevin Gimpel. (2016). Gaussian Error Linear Units (GELUs)
  7. PyTorch Community Discussion — Normalization.
  8. PyTorch Community Discussion — Visualize Feature Map.

Kaustav Mandal
exemplifyML.ai

Software Engineer with an interest in Machine Learning / Data Science, ML Ops