Should you use FastAI?

Thiago Dantas
deeplearningbrasilia
Feb 28, 2020
Photo by Markus Spiske from Pexels

Recently, I’ve been studying Deep Learning with Pytorch and FastAI. I’m always impressed by how fast FastAI models train and how well they perform without requiring much code, so I wanted to compare what I can get with pure Pytorch against what I can get with FastAI.

To test empirically whether FastAI is really that good, I chose a Kaggle dataset of chest X-ray images labeled for pneumonia classification.

1. Pure Pytorch

With the data downloaded from Kaggle, the first step is to import all the packages we’ll need.

1.1 Loading the data

Next, we define the transformations for our data augmentation pipeline. Here I’ll use some really simple transforms.

The data is organized in folders: one named “train” and one named “validation”, each containing two subfolders, “Normal” and “Pneumonia”, that hold the images of each class. This is a really common way to organize data for image classification, and thankfully Pytorch (through torchvision) has functionality that loads this layout directly into a dataset.

With the datasets and dataloaders created, we can plot a batch to check that everything looks right. I’ll define a function called “show_batch” to do so.

1.2 Creating and training a model

With the datasets and dataloaders defined, I now have to define a model. Pytorch (through torchvision) provides several world-class CNNs pretrained on ImageNet, so I’ll use the resnet50 model pretrained on ImageNet. I’ll use the CNN only as a feature extractor, which means I’ll train only the network’s fully connected layer. To train a model I need to define the model itself, a criterion (the loss function that guides training), an optimizer and a learning rate scheduler.

Finally, I have to define the training function.
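The post’s training function isn’t reproduced here, so below is a hypothetical `train_model` following the common PyTorch transfer-learning pattern: a training pass and a validation pass per epoch, with the scheduler stepped once per epoch. Restoring the weights with the best validation accuracy at the end is my assumption, not something the post states.

```python
import copy
import time

import torch


def train_model(model, criterion, optimizer, scheduler, dataloaders,
                num_epochs=3, device="cpu"):
    """Train for num_epochs, validating after each epoch; return the best model."""
    model = model.to(device)
    best_acc = 0.0
    best_weights = copy.deepcopy(model.state_dict())
    start = time.time()
    for epoch in range(num_epochs):
        for phase in ("train", "validation"):
            is_train = phase == "train"
            model.train(is_train)  # BatchNorm/Dropout behave differently in eval mode
            total_loss, correct, seen = 0.0, 0, 0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                # Build the autograd graph only when training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                total_loss += loss.item() * inputs.size(0)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                seen += inputs.size(0)
            if is_train:
                scheduler.step()  # epoch-level scheduler such as StepLR
            epoch_acc = correct / seen
            print(f"epoch {epoch + 1} {phase}: "
                  f"loss {total_loss / seen:.4f}, acc {epoch_acc:.4f}")
            # Keep a copy of the weights with the best validation accuracy.
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())
    print(f"finished in {time.time() - start:.0f}s, best val acc {best_acc:.4f}")
    model.load_state_dict(best_weights)
    return model, best_acc
```

On a GPU, pass `device="cuda"`; `dataloaders` is expected to be a dict with `"train"` and `"validation"` loaders.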

Trained with this function for 3 epochs, the model reaches 93.60% accuracy, and training takes 3min52s. Pretty good, huh? Now let’s see if a hybrid Pytorch-FastAI model can do better.

2. Hybrid Pytorch-FastAI

Now, the only difference from the pure Pytorch model is that I’ll use FastAI to train it: same transforms, same dataset and same model. Really, the only thing that changes is the training function.

Now I have everything I need to train this model with FastAI. FastAI learners have a really handy method, “lr_find()”. It trains briefly over a range of increasing learning rates and records the loss at each one, so you can pick a good learning rate to fit the model from the resulting plot.

Jeremy Howard, co-founder of FastAI, suggests picking a learning rate one decade (a factor of 10) lower than the learning rate at the minimum of the plot. So 0.01 seems like a good learning rate.
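To make the idea concrete, here is a bare-bones LR range test in plain PyTorch, plus `suggest_lr`, which encodes the one-decade-below-the-minimum rule of thumb. This is not FastAI’s `lr_find` (which, among other differences, smooths the recorded loss before plotting); the function names are hypothetical, and the defaults (1e-7 to 10 over 100 iterations) are chosen to resemble FastAI’s, to my knowledge.

```python
import copy
import math

import torch


def lr_range_test(model, criterion, train_loader, start_lr=1e-7, end_lr=10.0, num_iter=100):
    """Raise the LR exponentially after every batch and record the training loss."""
    model = copy.deepcopy(model)  # leave the caller's weights untouched
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=start_lr)
    gamma = (end_lr / start_lr) ** (1 / max(num_iter - 1, 1))  # per-step LR multiplier
    lrs, losses = [], []
    lr = start_lr
    batches = iter(train_loader)
    model.train()
    for _ in range(num_iter):
        try:
            inputs, labels = next(batches)
        except StopIteration:  # cycle through the loader if it is short
            batches = iter(train_loader)
            inputs, labels = next(batches)
        for group in optimizer.param_groups:
            group["lr"] = lr
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        if not math.isfinite(loss.item()):
            break  # the loss diverged: stop the sweep
        loss.backward()
        optimizer.step()
        lrs.append(lr)
        losses.append(loss.item())
        lr *= gamma
    return lrs, losses


def suggest_lr(lrs, losses):
    """Rule of thumb: one decade below the learning rate at the loss minimum."""
    best = min(range(len(losses)), key=losses.__getitem__)
    return lrs[best] / 10
```

Because the raw per-batch loss is noisy, its minimum can jump around between runs; plotting `losses` against `lrs` on a log scale and eyeballing the curve, as the post does, is usually more reliable than trusting `suggest_lr` blindly.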

Now we have everything we need to fit the model. FastAI has a fit method called “fit_one_cycle”, which is based on this paper (you can check this link for a simpler explanation). Basically, the one cycle policy warms the learning rate up to a maximum and then anneals it back down over the course of training, which leads to faster training.
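FastAI’s `fit_one_cycle` isn’t available in plain PyTorch, but PyTorch ships its own one-cycle scheduler, `torch.optim.lr_scheduler.OneCycleLR`. The sketch below records the per-batch learning rate it produces, just to show the warm-up-then-anneal shape; its defaults (for example the starting LR of `max_lr / 25` and cosine annealing) are PyTorch’s and may not match FastAI’s exactly. `one_cycle_lrs` is a hypothetical helper.

```python
import torch


def one_cycle_lrs(max_lr=0.01, epochs=3, steps_per_epoch=10):
    """Return the learning rate used at each batch under PyTorch's OneCycleLR."""
    param = torch.nn.Parameter(torch.zeros(1))  # stand-in for real model parameters
    optimizer = torch.optim.SGD([param], lr=max_lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, epochs=epochs, steps_per_epoch=steps_per_epoch)
    lrs = []
    for _ in range(epochs * steps_per_epoch):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()   # a real loop would run forward/backward first
        scheduler.step()   # OneCycleLR is stepped once per batch, not per epoch
    return lrs
```

The returned list rises from `max_lr / 25` to `max_lr` during the first 30% of the steps and then decays to a tiny value. In a training loop like the pure Pytorch one, you would call `scheduler.step()` after every `optimizer.step()` instead of once per epoch.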

Trained with this method for 3 epochs, the model reaches 94.79% accuracy and takes 3min50s to train. The FastAI-Pytorch hybrid takes about as long to train as the pure Pytorch model, but it reaches a higher accuracy. Finally, let’s see what accuracy a pure FastAI model achieves.

3. Pure FastAI

This time I won’t use the transforms I defined previously; I’ll let FastAI take care of everything. Let’s see how much code we need to obtain a world-class result.

That is all the code we need to get the data into a format suitable for training and to define a model. Now let’s take the same approach as before: search for a good learning rate and train for 3 epochs.

Again, 0.01 seems like a good learning rate. Now let’s call the fit method.

Trained with this method for 3 epochs, the model reaches 96.58% accuracy and takes about 8min to train.

Edit: you may be wondering why the pure FastAI model takes 8min to train while the others take about 4min. Studying the FastAI library further, I found out that when doing transfer learning for a classification task, FastAI trains not only the fully connected layer but also all the batch norm layers in the model. With more parameters to fit, it makes sense that training takes longer. Training the batch norm layers also makes sense because the images we’re training on are considerably different from the ones in ImageNet. This is probably one of the reasons why the pure FastAI model got the highest accuracy.

4. Conclusion

To wrap up, the pure FastAI model, with an impressive 96.58% accuracy, beat both the Pytorch-FastAI hybrid model, with 94.79%, and the pure Pytorch model, which obtained “only” 93.60%.

It is really nice to see how FastAI can help you get better results from your models. Imagine that the problem you are working on is not image classification and that it is not easy to create a databunch with FastAI’s factory methods. Even so, you can define your own custom Pytorch dataset and dataloader and load them into a databunch. You can also define your own very complicated model, custom loss function and custom optimizer, and train the model with FastAI’s “fit_one_cycle” method, which often trains faster and better than a standard fit function. I’ve done exactly that in this link.

Finally, I have to confess that I cheated with the pure Pytorch model. I chose its learning rate after training the Pytorch-FastAI hybrid model, so I already knew that 0.01 was a good learning rate. Had I trained the pure Pytorch model with a different learning rate, the FastAI models would probably have come out even further ahead.

P.S.: you can find all the code in this repo.
