Almost any Image Classification Problem using PyTorch

This is an experimental setup to build a code base for PyTorch. Its main aim is to experiment faster using transfer learning on all available pre-trained models. We will be using the Plant Seedlings Classification dataset for this blog post. It was hosted as a playground competition on Kaggle; more details are available on the competition page.

The following pre-trained models are available in PyTorch (torchvision):

  • resnet18, resnet34, resnet50, resnet101, resnet152
  • squeezenet1_0, squeezenet1_1
  • alexnet
  • inception_v3
  • densenet121, densenet169, densenet201
  • vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn

The three cases in Transfer Learning and how to solve them using PyTorch

I have already discussed the intuition behind transfer learning in my previous blog post, so I will just list the three cases here.

  1. Freezing all the layers except the final one
  2. Freezing the first few layers
  3. Fine-tuning the entire network.

This is very straightforward in PyTorch if you know how the models are structured and wrapped. The models listed above are not all written the same way: some wrap their final layers in Sequential containers, while others expose the last layer directly as an attribute. So it is important to check how each model is defined in PyTorch.

ResNet and Inception_V3

As mentioned before, there are several ResNets and we can use whichever we need. Since the ImageNet dataset has 1000 classes, we need to change the last layer to match the number of classes in our dataset. We can freeze whichever layers we don’t want to train and pass the parameters of the remaining layers to the optimizer (we will see this later).

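import torch.nn as nn
import torchvision
n_class = 12  # number of classes (plant species) in our dataset
if resnet:  # resnet / inception are boolean flags choosing the architecture
    model_conv = torchvision.models.resnet50(pretrained=True)
if inception:
    model_conv = torchvision.models.inception_v3(pretrained=True)
## Change the last layer
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, n_class)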

Let's check what model_conv contains. In PyTorch, a model is made of children (containers), and each child can contain several layers of its own. Below is the example for resnet50:

for name, child in model_conv.named_children():
    for name2, params in child.named_parameters():
        print(name, name2)
## A long list of parameters is printed; some of them are shown below:
conv1 weight
bn1 weight
bn1 bias
fc weight
fc bias

Now, if we want to freeze a few layers before training, we can simply do the following:

## Freezing all layers
for params in model_conv.parameters():
    params.requires_grad = False

## Alternatively, freezing only the first few layers. ct counts the top-level
## children from 1, so ct < 7 freezes the first 6 children
ct = 0
for name, child in model_conv.named_children():
    ct += 1
    if ct < 7:
        for name2, params in child.named_parameters():
            params.requires_grad = False

Changing the last layer to fit our new data is a bit tricky, and we need to check carefully how the underlying layers are represented. We have already seen this for ResNet and Inception_V3. Let's check the other networks.

Squeeze-Net

There are two variants of SqueezeNet in PyTorch, and we can use either of them. Unlike ResNet, which has an fc layer at the end, SqueezeNet's final layer is a 1x1 convolution wrapped inside a Sequential container (classifier). So we first list all the child layers inside that container, replace the layers that depend on our dataset, convert the list back into a container, and assign it back to the model. This is explained step by step in the code below.

import torch.nn as nn
import torchvision

n_class = 12  # number of classes (plant species) in our dataset
model_conv = torchvision.models.squeezenet1_1(pretrained=True)
## The top-level children are "features" and "classifier"
for name, child in model_conv.named_children():
    print(name)
## How many in_channels are there for the conv layer
in_ftrs = model_conv.classifier[1].in_channels
## How many out_channels are there for the conv layer (1000 for ImageNet)
out_ftrs = model_conv.classifier[1].out_channels
## Converting the Sequential container to a list of layers
features = list(model_conv.classifier.children())
## Changing the 1x1 conv layer to the required number of classes
features[1] = nn.Conv2d(in_ftrs, n_class, kernel_size=1, stride=1)
## Changing the pooling layer to match the feature-map size produced by our input size
## (kernel 12 assumes a 12x12 feature map)
features[3] = nn.AvgPool2d(12, stride=1)
## Making a container to list all the layers
model_conv.classifier = nn.Sequential(*features)
## Mentioning the number of output classes
model_conv.num_classes = n_class

DenseNet

It is very similar to ResNet, but the last layer is called classifier instead of fc. Below is the code:

import torch.nn as nn
import torchvision
n_class = 12  # number of classes (plant species) in our dataset
model_conv = torchvision.models.densenet121(pretrained=True)
num_ftrs = model_conv.classifier.in_features
model_conv.classifier = nn.Linear(num_ftrs, n_class)

VGG and Alex-Net

It is similar to Squeeze-net. The last fc layers are wrapped inside a container, so we need to read that container and change the last fc layer as per our dataset requirements.

import torch.nn as nn
import torchvision

n_class = 12  # number of classes (plant species) in our dataset
model_conv = torchvision.models.vgg19(pretrained=True)
## Number of input features to the last fc layer
num_ftrs = model_conv.classifier[6].in_features
## Convert all the layers to a list and remove the last one
features = list(model_conv.classifier.children())[:-1]
## Add the last layer based on the number of classes in our dataset
features.extend([nn.Linear(num_ftrs, n_class)])
## Convert it into a container and add it to our model class
model_conv.classifier = nn.Sequential(*features)

We have seen how to freeze the required layers and change the last layer for different networks. Now let's train the network using one of these nets. I won't go into detail here, as the full code is available in my GitHub repo.

Base code

Like any deep learning model, we first need to:

  • Define a network
  • Load pre-trained weights if available
  • Freeze the layers you don’t want to train (frozen layers act as a feature extractor)
  • Define the loss
  • Choose the optimizer for training
  • Train the network until your chosen criterion is met
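The steps above map onto a short, generic training loop. This is a minimal sketch, not the code from the repo: run_epoch and fit are hypothetical names, and stopping once validation accuracy has not improved for patience epochs is just one example of a stopping criterion.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, optimizer=None, device="cpu"):
    """One pass over loader; trains when an optimizer is given, otherwise only evaluates."""
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += inputs.size(0)
    return total_loss / seen, correct / seen


def fit(model, train_loader, val_loader, criterion, optimizer, max_epochs=25, patience=3):
    """Train until validation accuracy has not improved for `patience` epochs (or max_epochs is hit)."""
    best_acc, epochs_without_gain, history = 0.0, 0, []
    for _ in range(max_epochs):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion)
        history.append((train_loss, train_acc, val_loss, val_acc))
        if val_acc > best_acc:
            best_acc, epochs_without_gain = val_acc, 0
        else:
            epochs_without_gain += 1
            if epochs_without_gain >= patience:
                break
    return history


# Tiny synthetic smoke test: a linear model learning whether the first feature is positive
torch.manual_seed(0)
x = torch.randn(64, 10)
y = (x[:, 0] > 0).long()
loader = DataLoader(TensorDataset(x, y), batch_size=16)
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
history = fit(model, loader, loader, nn.CrossEntropyLoss(), optimizer, max_epochs=5, patience=2)
print(len(history), history[-1])
```

In practice the validation loader would hold a held-out split of the training images; reusing the training loader here only keeps the smoke test short.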

Now let's look at how this is done for inception_v3 in PyTorch. We will freeze the first few layers and train the network using an SGD optimizer with momentum and a cross-entropy loss.

Dataset Used


I have gone a bit further to check how these models perform under different settings. My idea was to train all the networks, see how the individual models perform, and later apply different ensembling methods to improve the accuracy. I also wanted to know how diverse these models are. So here are the metrics on the training and validation datasets.


Cadene has trained several nets that are not available in PyTorch. I have used some of his code to train the following models:

  • resnext101_64x4d
  • resnext101_32x4d
  • nasnetalarge
  • inceptionresnetv2
  • inceptionv4

I am facing issues with bninception and vggm; I will update this post soon.

Next steps

  1. Adding the mixup strategy to all the networks
  2. Ensembling model outputs
  3. Model stacking
  4. Extracting bottleneck features and training other ML models on them
  5. Visualization using t-SNE
  6. Solving the issue with bninception (the model is not training)
  7. Training the vggm network
  8. SE-Net implementation and training
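Two of these items are easy to sketch. Mixup trains on convex combinations of pairs of examples and mixes their losses with the same weight; ensembling can be as simple as averaging the softmax outputs of several models (soft voting), which is only one of several possible ensembling methods. These are hypothetical helpers, not the implementation we are working on, and alpha=0.4 is an arbitrary example value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mixup_batch(inputs: torch.Tensor, labels: torch.Tensor, alpha: float = 0.4):
    """Blend each example with a random partner from the same batch (mixup)."""
    # lam ~ Beta(alpha, alpha); alpha <= 0 disables mixing
    lam = torch.distributions.Beta(alpha, alpha).sample().item() if alpha > 0 else 1.0
    perm = torch.randperm(inputs.size(0))
    mixed = lam * inputs + (1.0 - lam) * inputs[perm]
    return mixed, labels, labels[perm], lam


def mixup_loss(criterion: nn.Module, outputs, labels_a, labels_b, lam: float) -> torch.Tensor:
    """Mix the loss with the same weight that was used to mix the inputs."""
    return lam * criterion(outputs, labels_a) + (1.0 - lam) * criterion(outputs, labels_b)


def ensemble_predict(models, inputs: torch.Tensor) -> torch.Tensor:
    """Soft voting: average the softmax probabilities of several trained models."""
    probs = []
    with torch.no_grad():
        for model in models:
            model.eval()  # disable dropout and use BatchNorm running statistics
            probs.append(F.softmax(model(inputs), dim=1))
    return torch.stack(probs).mean(dim=0)
```

In a training loop, mixup_batch replaces the raw batch and mixup_loss replaces the plain criterion call; ensemble_predict is only used at inference time.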

Final Submission Results:

Ranked 29th on the Kaggle leaderboard at the time of submission.

GitHub Repo and stuff

This blog post is not yet complete. I will add more to it, so stay tuned.


Vikas Challa is working on Mix-Up. We will publish the results soon.

Co-Contributors: Vikas Challa and Sachin Chandra




Senior Data Scientist @Qure.AI | Ex Fractal.AI

Prakash Jay
