Why use a pre-trained model rather than creating your own?

Florin-Daniel Cioloboc
Udacity PyTorch Challengers
Jan 4, 2019

In the following paragraphs, I’m going to explain why you should consider using pre-trained models instead of creating one from scratch.

To follow this article effectively, you should know what neural networks are, how they work, and how to train them. If you covered everything from Lessons 1 and 3 of Udacity’s Intro to PyTorch, you should be in a good position for what comes next, and you can fill in any gaps you might have from Lesson 4.

If you haven’t taken the course yet, you can still follow along if you have some knowledge of neural networks and minimal experience implementing one in your favorite library (PyTorch preferably :P).

As for the working environment, I prefer Jupyter-based ones such as Kaggle Kernels or Google Colab, so you don’t need anything installed on your computer other than a browser if you want to test something.

What do you mean by pre-trained models?

A pre-trained model is a model that has already been trained for a certain task on a large dataset; in the case of the image classification models that ship with PyTorch, that dataset is ImageNet. Several very popular architectures are available to load into your notebook, such as VGG, ResNet, DenseNet, and Inception, among others, and the full list is in the PyTorch models documentation [2]. Reusing such a model for a new task is known as transfer learning. Does that sound familiar? If you remember, we got a first glimpse of it in Lesson 4, specifically in notebooks 7 and 8 [1].

Below is how to load several pre-trained models [2]:

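import torchvision.models as models

# torchvision >= 0.13 selects pre-trained weights with the `weights=` argument;
# older releases used the now-deprecated `pretrained=True` flag instead.
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
alexnet = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
squeezenet = models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.DEFAULT)
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
densenet = models.densenet161(weights=models.DenseNet161_Weights.DEFAULT)
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)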

What is Transfer Learning?

This whole process is called Transfer Learning, and it’s actually a bit more than just importing a model into your working environment. One of the best summaries I could find about why you should use Transfer Learning comes from a very popular resource, the CS231n course by Stanford University:

“In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a data set of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.” — [3]

A few years ago, it was argued that Transfer Learning would be the next big driver of machine learning’s success in industry, and that claim is still valid [4][5].

Why should I use a pre-trained model?

If that hasn’t convinced you yet, here are a few more reasons to support my claim.

  • The issue with not using one is that, depending on your experience, you will spend a serious amount of time training your model from scratch. You will have to run plenty of calculations and experiments to arrive at a suitable CNN architecture. Think about how many design questions you will have to sort out: How many layers do I need? What about pooling? Should I stack several convolutional layers? Moreover, the complexity of the data set will weigh in as well.
  • You might not have a data set large enough for your model to generalize well, and you might not have the computational resources to train on one either.
  • Keep in mind that ImageNet has 1000 classes, so the pre-trained models have learned features that are useful for recognizing a wide variety of objects.
  • The hard work of optimizing the parameters has already been done for you. What remains is to fine-tune the model for your task, typically by replacing its final layer, retraining some or all of its weights on your data, and tuning hyperparameters such as the learning rate. In that sense, a pre-trained model is a lifesaver.

Fine-tuning a model deserves an article of its own, so I won’t really go into it here. However, I can suggest where to read about it. For starters, I’d recommend PyTorch’s tutorial, my repository on the course, and another resource written for Keras (still useful for ideas).

Some strategies for retraining a pre-trained model from Stanford’s CS230 [6]

Will a pre-trained model always fit my problem?

It depends on what you are trying to do:

  • If you are doing research on building a different model, then you would probably not be reading this ;)
  • If you are doing it for experimentation purposes, you might be interested in loading the architecture with randomly initialized parameters, so no.
  • If your data set is very different from ImageNet and you’ve already tried retraining a pre-trained model without success, then probably not.
  • If you are interested in learning how CNNs work, you should focus your efforts on building one from scratch.

The reasons for not using one are simple and straightforward. The list could probably go on, but these were the ones that came to mind while writing.

How do I decide what pre-trained model architecture to use?

I saved this question for last because I believe it is one of the most challenging aspects of the topic, and one we all have, especially as beginners. To be honest, I can’t give a definitive answer, since it most likely depends on the problem you’re trying to solve; however, I can give you some insight into how I would approach it.

First, I would definitely research the problem by searching for its key words together with the name of the data set I’m using. Why the data, though? Depending on how complex your data set is, some models might work better than others, and many factors weigh in here. How deep does the model need to be?

Once I find a model architecture that looks useful, I read carefully which benchmark results it achieved in experiments. PyTorch and other frameworks report top-1 and top-5 error rates on ImageNet for their pre-trained models, so this can also be a good starting point [1][7].
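Top-1 and top-5 error rates are also easy to compute yourself if you want to compare candidate models on your own validation data: a prediction counts as correct for top-k if the true class is among the k highest-scoring classes. `topk_error` below is a hypothetical helper, shown on a tiny hand-made batch.

```python
import torch

def topk_error(logits: torch.Tensor, targets: torch.Tensor, k: int = 1) -> float:
    """Fraction of samples whose true class is NOT among the k highest-scoring classes."""
    topk = logits.topk(k, dim=1).indices              # (N, k) predicted class indices
    hits = (topk == targets.unsqueeze(1)).any(dim=1)  # (N,) True if the target is in the top k
    return 1.0 - hits.float().mean().item()

# Tiny hand-made batch: 3 samples, 6 classes.
logits = torch.tensor([
    [0.1, 0.9, 0.0, 0.0, 0.0, 0.0],  # true class 1 is ranked 1st
    [0.5, 0.4, 0.3, 0.2, 0.1, 0.0],  # true class 4 is ranked 5th
    [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],  # true class 0 is ranked last (6th)
])
targets = torch.tensor([1, 4, 0])

print(topk_error(logits, targets, k=1))  # about 0.667: only the first sample is a top-1 hit
print(topk_error(logits, targets, k=5))  # about 0.333: the first two samples are top-5 hits
```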

Lastly, what I have presented here is just a small part of Transfer Learning and barely scratches the surface. From here, I recommend checking out CS231n, Sebastian Ruder’s blog, and the Model Zoo to learn more.

I encourage you to explore different pre-trained models and even develop one from scratch to beat your best result.

Resources that I’ve used:

  1. Udacity Deep Learning repo: https://github.com/udacity/deep-learning-v2-pytorch/blob/master/intro-to-pytorch/Part%208%20-%20Transfer%20Learning%20(Exercises).ipynb
  2. PyTorch models documentation: https://pytorch.org/docs/stable/torchvision/models.html
  3. Transfer learning (CS231n): http://cs231n.github.io/transfer-learning/#tf
  4. Sebastian Ruder’s blog: http://ruder.io/transfer-learning/index.html#whytransferlearningnow
  5. Convolutional neural networks (CS231n): http://cs231n.github.io/convolutional-networks/
  6. Deep Learning cheatsheet: https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-deep-learning-tips-and-tricks#parameter-tuning
  7. Model architecture benchmarks: https://github.com/jcjohnson/cnn-benchmarks
