Training EfficientDet on custom data with PyTorch-Lightning (using an EfficientNetv2 backbone)

Chris Hughes
Data Science at Microsoft
Jul 27, 2021

Object detection remains a popular and challenging area in computer vision, and for good reason: It can be applied to a wide variety of real-world scenarios, ranging from recognizing road signs during autonomous driving to detecting defects at the end of a manufacturing production line — and even identifying areas of interest from medical scans. As such, we can expect the interest in this area to continue to grow.

While there is an ever-growing selection of object detection models to choose from, a popular choice for many applications is EfficientDet, released by the Google Brain team in July 2020. At the time of publication, the largest model (EfficientDet-D7) demonstrated state-of-the-art results, and almost a year later it remains a popular choice, performing competitively in recent Kaggle competitions such as Global Wheat Detection and NFL 1st and Future — Impact Detection.

Comparison of different object detection architectures. Source: EfficientDet: Scalable and Efficient Object Detection (arxiv.org)

Despite this, when recently working on an object detection project as part of Microsoft Commercial Software Engineering (CSE), I struggled to find a PyTorch implementation of EfficientDet that could be quickly and easily adapted to my problem. The implementations that I came across were either coupled to a particular dataset during training, lacked features such as multi-GPU training, or were incompatible with the latest versions of the underlying packages that they were based upon.

Here, I aim to provide a clean and clear starting point for anyone wishing to experiment with EfficientDet by providing a bare-bones PyTorch-Lightning implementation that can be easily adapted to new problems. I focus purely on the implementation, and not on how EfficientDet works, as there are many excellent blog posts in this area already. Also, as EfficientNetv2 has recently been released, I thought it would be fun to show how it can be used as the backbone for EfficientDet in place of EfficientNetv1.

Comparison of EfficientNetv2 and other recent computer vision models. Source: EfficientNetV2: Smaller Models and Faster Training (arxiv.org)

To keep the code as clear and minimal as possible, I opted not to include any metrics other than the training and validation losses, as these can be added later depending on your specific problem with little effort; I briefly discuss how to do this later. For general object detection metrics, I recommend a repo recently created by my colleague that provides a clean interface to the pycocotools package.

TL;DR: If you just want to see some working code that you can use directly, all of the code required to replicate this post is available as a GitHub gist here. While gists are used as code snippets throughout this article, this is primarily for aesthetic reasons, and these snippets may not work as intended if copied directly. For working implementations, please refer to the notebook in the gist linked above.

The packages used are:

Package versions used when running this code.

Acknowledgments

Before beginning, I want to highlight a couple of excellent resources that inspired this work:

  • For the model implementation and pretrained weights, this work heavily utilizes Ross Wightman’s awesome EfficientDet-Pytorch (effdet) and pytorch-image-models (timm) packages. Ross’s dedication to providing implementations of state-of-the-art computer vision models that are easily accessible to the whole data science community is second to none. If you haven’t already, go and add stars!
  • Alex Shonenkov has a clear and concise Kaggle kernel that illustrates fine-tuning EfficientDet to detect wheat heads using EfficientDet-PyTorch; it appears to be the starting point for most similar Kaggle solutions that I came across. Even though it’s out of date with the current version of EfficientDet-PyTorch, it still acted as a valuable foundation for this work.
  • Most of the powerful, out-of-the-box features of this implementation come as a direct result of using PyTorch-Lightning. As there are many blog posts outlining the features and advantages of using lightning, I do not cover it here, but for those who are unfamiliar, lightning essentially provides a clean interface that you can use to organize your vanilla PyTorch code, which then enables features such as gradient accumulation and distributed training with no further code changes — awesome, right?!

Selecting a dataset

As an example, I use the Kaggle cars object detection dataset; however, as my aim is to demonstrate how EfficientDet can be applied to any problem, this is really the least important part of this work.

As this dataset is quite small, and the designated test set is unlabeled, for simplicity I focus on training and evaluating the model on the training set. While this is never something that should be done in practice as a method to evaluate a model’s performance, it is a useful trick to test whether the model is capable of learning the task; if the model is unable to overfit on the training set, it clearly does not have the capability to learn the task and we can expect poor performance when the model is used on new data.

The annotations for this dataset are in the form of a .csv file, which associates the image name with the corresponding annotations, and we can view the format of this by loading it into a pandas DataFrame.
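As a minimal sketch of this step, assume the annotations file has the columns `image`, `xmin`, `ymin`, `xmax` and `ymax` (the column names are an assumption; check them against your copy of the dataset). To keep the snippet self-contained, a small in-memory string with hypothetical filenames and coordinates stands in for the real .csv file:

```python
import io

import pandas as pd

# Hypothetical stand-in for the dataset's annotations file; with the real data,
# pass the path of the .csv file to pd.read_csv instead.
csv_text = """image,xmin,ymin,xmax,ymax
car_001.jpg,281.0,187.0,327.0,223.0
car_001.jpg,15.0,187.0,120.0,236.0
car_002.jpg,239.0,176.0,361.0,236.0
"""

annotations_df = pd.read_csv(io.StringIO(csv_text))

# Each row pairs one image filename with one bounding box, so an image that
# contains several cars appears on several rows.
print(annotations_df.head())
```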

Here, we can see that each row associates the image filename with a bounding box in Pascal VOC format, that is, [x_min, y_min, x_max, y_max] in pixel coordinates.

Creating a dataset adaptor

Usually, at this point, we would create a PyTorch dataset to feed this data into the training loop. However, some of this code, such as normalizing the image and transforming the labels into the required format, is not specific to this problem and will need to be applied regardless of which dataset is being used. Therefore, let’s focus for now on creating a CarsDatasetAdaptor class, which converts the specific raw dataset format into an image and corresponding annotations. An implementation of this is presented below:
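The gist itself is not reproduced in this text, so here is a minimal sketch of such an adaptor, assuming the `image`/`xmin`/`ymin`/`xmax`/`ymax` column layout. The `load_image` parameter is a hypothetical addition (not part of the original) that lets the class be exercised without image files, and the `show_image` convenience method is omitted:

```python
from pathlib import Path

import numpy as np
import pandas as pd


class CarsDatasetAdaptor:
    """Wraps the cars annotations DataFrame, exposing images and labels by index."""

    def __init__(self, images_dir_path, annotations_dataframe, load_image=None):
        self.images_dir_path = Path(images_dir_path)
        self.annotations_df = annotations_dataframe
        # One entry per unique image; an image may have several annotation rows.
        self.images = self.annotations_df["image"].unique().tolist()
        # Hypothetical hook for testing; by default images are opened with Pillow.
        self._load_image = load_image or self._open_with_pil

    @staticmethod
    def _open_with_pil(path):
        from PIL import Image  # imported lazily so Pillow is only needed when used

        return Image.open(path)

    def __len__(self) -> int:
        return len(self.images)

    def get_image_and_labels_by_idx(self, index):
        image_name = self.images[index]
        image = self._load_image(self.images_dir_path / image_name)
        rows = self.annotations_df[self.annotations_df["image"] == image_name]
        # Pascal VOC boxes: [x_min, y_min, x_max, y_max] in pixels, shape [N, 4].
        pascal_bboxes = rows[["xmin", "ymin", "xmax", "ymax"]].to_numpy(dtype=np.float32)
        # Single-class dataset: every box gets label 1 (classes start at 1).
        class_labels = np.ones(len(pascal_bboxes), dtype=np.int64)
        # The image's position in the dataset serves as its unique image_id.
        return image, pascal_bboxes, class_labels, index
```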

From this, we can see that the main functionality we have implemented is the __len__ method and the get_image_and_labels_by_idx method, which returns a tuple containing:

  • image: a PIL image
  • pascal_bboxes: a numpy array of shape [N, 4] containing the ground truth bounding boxes in Pascal VOC format
  • class_labels: a numpy array of shape [N] containing the ground truth class labels
  • image_id: a unique identifier that can be used to identify the image

In this case, this class simply wraps the DataFrame provided with the dataset. As this dataset contains only a single class, a label of 1 is always returned. For EfficientDet, the classes should start at 1, with -1 being used for the “background” class.

Additionally, as the image_id can be any unique identifier associated with the image, here we have just used the index of the image in the dataset. We have also implemented a show_image method for convenience. The function that is called to display the image is defined in the corresponding notebook.

We can now create an instance of this class to provide a clean interface to view the training data, a selection of which is seen below:

More generally, we could think of the dataset adaptor as an abstract class with the following interface:
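One way to write that interface down is with Python’s `abc` module. This is a minimal sketch; the class name `DatasetAdaptor` is hypothetical:

```python
from abc import ABC, abstractmethod


class DatasetAdaptor(ABC):
    """Hypothetical base class describing what any dataset adaptor must provide."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of images in the dataset."""

    @abstractmethod
    def get_image_and_labels_by_idx(self, index):
        """Return (image, pascal_bboxes, class_labels, image_id) for one image."""
```

A subclass that forgets either method cannot be instantiated (Python raises a `TypeError`), which is the main thing an explicit base class buys over duck typing.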

As we are in Python and have the luxury of “duck typing,” actually defining this base class and subclassing seems a little excessive, but hopefully it helps illustrate the required interface!

When using a different dataset, this is the part that changes!

Creating the model

Now, let’s look at creating the EfficientDet model. Thanks to Ross Wightman’s effdet and timm libraries, we have many options here. The effdet package includes a selection of different EfficientDet configurations that can be used. We can view a selection of these below.

Some of these implementations (e.g., efficientdet_d5) have been trained by Ross in PyTorch, whereas any implementation prefixed by “tf_” uses the official pretrained weights. As the original models were trained in TensorFlow, certain modifications (such as implementing “same” padding) have been made so that these weights can be used in PyTorch, which means that these models may be slower during training and inference.

In addition to the provided configs, we can also use any model from timm as our EfficientDet backbone. Here, let’s try using one of the new EfficientNetv2 models as the backbone. Similar to before, we can list these models using timm:

To use one of these models, we first must register it as an EfficientDet config by adding a dictionary to the “efficientdet_model_param_dict”. Let’s create a function that does this for us, and then creates the EfficientDet model using the machinery from effdet:
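A sketch of such a function follows, assuming a recent 0.2.x release of effdet, which provides `get_efficientdet_config`, `EfficientDet` and `DetBenchTrain` at the package level, `HeadNet` in `effdet.efficientdet`, and `efficientdet_model_param_dict` in `effdet.config.model_config`; check these names against your installed version. `register_timm_backbone` is a hypothetical helper split out for readability, and the `drop_path_rate` default is an assumption. The effdet imports sit inside `create_model` so that the helper can be used on its own:

```python
def register_timm_backbone(param_dict, backbone_name, num_classes, drop_path_rate=0.2):
    """Add an EfficientDet config entry that uses a timm model as its backbone.

    `param_dict` is expected to be effdet's `efficientdet_model_param_dict`;
    settings not given here fall back to effdet's default detection config.
    """
    param_dict[backbone_name] = dict(
        name=backbone_name,
        backbone_name=backbone_name,
        backbone_args=dict(drop_path_rate=drop_path_rate),
        num_classes=num_classes,
        url="",  # no pretrained EfficientDet weights for this combination
    )
    return param_dict[backbone_name]


def create_model(num_classes=1, image_size=512, architecture="tf_efficientnetv2_l"):
    from effdet import DetBenchTrain, EfficientDet, get_efficientdet_config
    from effdet.config.model_config import efficientdet_model_param_dict
    from effdet.efficientdet import HeadNet

    register_timm_backbone(efficientdet_model_param_dict, architecture, num_classes)

    config = get_efficientdet_config(architecture)
    config.update({"num_classes": num_classes, "image_size": (image_size, image_size)})

    # Backbone weights come from timm; the BiFPN and heads start untrained.
    net = EfficientDet(config, pretrained_backbone=True)
    # Replace the classification head so its output size matches our classes.
    net.class_net = HeadNet(config, num_outputs=config.num_classes)
    # DetBenchTrain wraps the model so that calling it with targets returns losses.
    return DetBenchTrain(net)
```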

Here, once the EfficientDet model is created, we modify the classification head based on the number of classes for our current task. We have set the default image size to 512, as used in the paper; due to the architecture of EfficientDet, the input image size must be divisible by 128. We can now use this function to create our PyTorch EfficientDet model.
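The divisibility constraint comes from the feature pyramid: EfficientDet uses levels P3 to P7, and the coarsest level, P7, has a stride of 2^7 = 128 pixels. A small guard function (a hypothetical helper, not part of effdet) makes the check explicit:

```python
def validate_image_size(image_size: int, max_level: int = 7) -> int:
    """Check that an input size is compatible with an EfficientDet feature pyramid.

    The coarsest pyramid level (P7 by default) downsamples by 2**max_level = 128,
    so the input size must be a whole multiple of that stride.
    """
    stride = 2 ** max_level
    if image_size % stride != 0:
        raise ValueError(f"image_size={image_size} is not divisible by {stride}")
    return image_size
```

For example, 512 and 640 pass, while 500 is rejected.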

Creating an EfficientDet Dataset and DataModule

Now, let’s move on to loading data that can be fed into our model. Let’s start by defining some transforms, which need to be applied before passing the images and labels into the model. For this, we can use the excellent Albumentations library, which contains a wide variety of data augmentation methods.

Here, to keep things simple, we apply only the essential pre-processing during validation: as the backbone was pretrained on ImageNet, we normalize the image using the mean and standard deviation of the ImageNet dataset, resize it, and convert it to a tensor. During training, we also add a horizontal flip. Note that we have to pass the bounding boxes to these transforms, because Albumentations applies any spatial transformations to the labels as well!
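To see why the boxes must travel with the image, here is a NumPy sketch of what a horizontal flip and ImageNet normalization do. It illustrates the arithmetic only and is not Albumentations’ own code; the function names are hypothetical, and the mean and standard deviation are the standard ImageNet values:

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_image(image_uint8: np.ndarray) -> np.ndarray:
    """Scale an HxWx3 uint8 image to [0, 1], then standardize each channel."""
    image = image_uint8.astype(np.float32) / 255.0
    return (image - IMAGENET_MEAN) / IMAGENET_STD


def hflip_image_and_boxes(image: np.ndarray, pascal_bboxes: np.ndarray):
    """Mirror an image left-to-right and update Pascal VOC boxes to match."""
    width = image.shape[1]
    flipped_image = image[:, ::-1]
    x_min, y_min, x_max, y_max = pascal_bboxes.T
    # A box's new left edge is the mirror of its old right edge, and vice versa;
    # the y coordinates are unchanged.
    flipped_boxes = np.stack([width - x_max, y_min, width - x_min, y_max], axis=1)
    return flipped_image, flipped_boxes
```

For a 10-pixel-wide image, a box spanning x = 1 to 3 ends up spanning x = 7 to 9 after the flip; if only the pixels were flipped, the label would point at the wrong region.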

As we can see, it would be a straightforward task to add additional augmentations!

Now that we have defined our transforms, we can move on to defining a dataset to wrap our dataset adaptor and apply the transformations. The only gotcha to watch out for is that EfficientDet requires the bounding boxes in YXYX format, that is, [y_min, x_min, y_max, x_max]. We can see the implementation of this below:
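The format change itself is only a column swap. A minimal NumPy version (the function name is illustrative) looks like this; the same indexing expression also works on a PyTorch tensor:

```python
import numpy as np


def pascal_voc_to_yxyx(pascal_bboxes: np.ndarray) -> np.ndarray:
    """Reorder [x_min, y_min, x_max, y_max] boxes to [y_min, x_min, y_max, x_max]."""
    # Swapping columns (0, 1) and (2, 3) is all that is needed; no values change.
    return pascal_bboxes[:, [1, 0, 3, 2]]
```

Applying the same swap a second time converts YXYX boxes back to Pascal VOC, which is handy when turning predictions back into plottable boxes.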

While we could now use this dataset to create a standard PyTorch DataLoader, PyTorch-lightning provides a DataModule class, which we can use to group all our data-related components. The interface is quite intuitive, and we can implement it as follows:

As well as helping to keep our code clean, as we see later, we can use this directly during training instead of having to manually create the Dataset and DataLoader components.

Defining the training loop

OK, still with me? We’re almost there! In PyTorch-lightning, we tie the model, training loop, and optimizer together in a LightningModule. So instead of having to define our own loops to iterate over each DataLoader, we can do the following:

The important parts of the training loop are captured in the “training_step” and “validation_step” methods. Later, we add a couple of extra methods to this class to handle inference, but for now, we have all that we need to start training!

Training the model

OK, let’s train the model! Because we have used PyTorch-Lightning for our training loop and data loaders, this part is super easy!

Even better, although we are using only a single GPU for this example, we could easily train this model on multiple GPUs — even across multiple nodes — with no additional code changes other than changing the Trainer arguments!

Inference

Great, we’ve trained the model! Now let’s look at using it to make some predictions. Let’s start by adding a predict function to the lightning module that we defined earlier. Here, I have used the typedispatch decorator from fastcore to overload the predict method depending on the input type.
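The article uses fastcore’s `typedispatch` for this. To illustrate the same idea with only the standard library, here is a sketch using `functools.singledispatchmethod` (a swapped-in technique, not the article’s code): a list of images and an already-batched NumPy array are each converted into a common format before being handed to a single internal method. The `Predictor` class and the `_run_inference` stub are hypothetical:

```python
from functools import singledispatchmethod

import numpy as np


class Predictor:
    @singledispatchmethod
    def predict(self, images):
        raise NotImplementedError(f"Unsupported input type: {type(images).__name__}")

    @predict.register
    def _(self, images: list):
        # A list of HxWx3 images: stack them into one NxHxWx3 batch.
        batch = np.stack([np.asarray(image) for image in images])
        return self._run_inference(batch)

    @predict.register
    def _(self, images: np.ndarray):
        # Already a batch (e.g. straight from a data loader): use it as-is.
        return self._run_inference(images)

    def _run_inference(self, batch: np.ndarray):
        # Placeholder: a real implementation would run the model here.
        return batch.shape
```

`singledispatchmethod` dispatches on the type of the first argument after `self`, which is the same overloading-by-input-type behavior described above.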

As we can see, both methods essentially transform the input data into the required format before delegating to an internal “_run_inference” method. Let’s dive into that method and see what’s going on.

From the implementation we can see that there are several steps involved:

  • effdet has two classes that wrap the EfficientDet model: DetBenchTrain and DetBenchPredict. As DetBenchTrain is designed to be used for training, when passing images to the model, it also requires a set of targets, whereas DetBenchPredict does not. However, the implementations of these classes are very similar, and to avoid having to break apart the model post-training before using it for inference, I thought it simpler to build the inference capabilities around the same class that we use to train. To do this, we first create a set of dummy targets to pass to the model, to satisfy the DetBenchTrain class. As we are not interested in the loss during inference, the values of these are arbitrary.
  • The model returns a vector containing the bounding boxes, the confidence scores, and the classes for each image. We use a couple of convenience methods to unpack these, and only return the predictions where the score is higher than a defined threshold. We also apply weighted box fusion to the predictions, so overlapping boxes are combined.
  • As the images are resized before they are fed to the model, the bounding boxes returned are relative to the resized images. To use the predicted boxes with the original image, we resize the bounding boxes before they are returned.
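The threshold filtering and box rescaling steps can be sketched in NumPy. The sketch assumes that each image’s detections arrive as rows of `[x_min, y_min, x_max, y_max, score, class]`; this column order is an assumption about effdet’s output, so check it against your installed version. The function names and the default threshold are hypothetical, and weighted box fusion is not reproduced here:

```python
import numpy as np


def filter_detections(detections: np.ndarray, score_threshold: float = 0.2):
    """Split one image's detections into boxes, scores, and labels above a threshold.

    `detections` is assumed to have shape [num_detections, 6] with columns
    [x_min, y_min, x_max, y_max, score, class].
    """
    keep = detections[:, 4] > score_threshold
    kept = detections[keep]
    return kept[:, :4], kept[:, 4], kept[:, 5].astype(np.int64)


def rescale_boxes(pascal_bboxes: np.ndarray, resized_hw, original_hw) -> np.ndarray:
    """Map Pascal VOC boxes from the resized (model input) image to the original.

    Both sizes are given as (height, width); x coordinates scale with the width
    ratio and y coordinates with the height ratio.
    """
    resized_h, resized_w = resized_hw
    original_h, original_w = original_hw
    scale = np.array([original_w / resized_w, original_h / resized_h] * 2)
    return pascal_bboxes * scale
```

For example, a box predicted on a 512 × 512 input at [0, 0, 256, 512] maps to [0, 0, 338, 1024] on a 676-pixel-wide, 1024-pixel-high original.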

We can now use the predict function to see how the model performs:

Plotting the predicted boxes, we can compare these to the ground truths:

After five epochs, we can clearly see that the model has learned the task!

Using Model hooks for manual debugging

Validation step outputs and adding COCO metrics

One feature of PyTorch lightning is that it uses methods, or “hooks”, to represent each part of the training process. While we lose some visibility over our training loop when using the Trainer, we can use these hooks to easily debug each step.

For example, we can use a hook defined on our DataModule to get the DataLoader that is used during validation and use this to grab a batch.

We can use this batch to see exactly what the model calculated during validation. Using the model’s hook, we can see what is calculated for each batch during each validation step.

Here, we can see that the loss is returned for the batch, as well as the predictions and targets. To calculate metrics for the epoch, we need to get the predictions corresponding to each batch. As the “validation_step” method will be called for each batch, let’s define a function to aggregate the outputs.
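A framework-free sketch of such an aggregation function follows. It assumes (hypothetically) that each step output is a dict holding a `"loss"` value and a `"batch_predictions"` dict with per-image `"predictions"`, `"targets"` and `"image_ids"` lists; the key names are illustrative, and in the real module these entries would hold PyTorch tensors rather than plain values:

```python
def aggregate_prediction_outputs(outputs):
    """Flatten per-batch validation outputs into per-image lists for the epoch."""
    predictions, targets, image_ids = [], [], []
    for output in outputs:
        batch = output["batch_predictions"]
        # Each batch contributes one entry per image, so extend rather than append.
        predictions.extend(batch["predictions"])
        targets.extend(batch["targets"])
        image_ids.extend(batch["image_ids"])
    return predictions, targets, image_ids
```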

Here, for simplicity, we patch this function to the EfficientDet class using a convenience decorator from fastcore — we pay a high performance price for Python being a dynamic language, so we may as well make the most of it!

From the PyTorch-lightning docs, we can see that we can add an additional hook, “validation_epoch_end”, which is called after all batches have been processed; at the end of each epoch, a list of step outputs is passed to this hook. We can define this as follows:
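The loss part of that hook is just a mean over the step outputs. Here is a framework-free sketch, assuming each step output is a dict with a `"loss"` entry (the function name is hypothetical; in Lightning the losses would be tensors and the hook would typically log the result). Note that `validation_epoch_end` was removed in PyTorch Lightning 2.0 in favor of `on_validation_epoch_end`, which does not receive the step outputs, so on newer versions the outputs must be collected manually:

```python
def average_validation_loss(outputs):
    """Mean of the per-batch losses collected by validation_step."""
    if not outputs:
        raise ValueError("No validation outputs to aggregate")
    return sum(float(output["loss"]) for output in outputs) / len(outputs)
```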

Let’s use this hook to calculate the overall validation loss, as well as the COCO metrics using the “objdetecteval” package. To illustrate how this would work, we can use the output that we just calculated when evaluating a single validation batch, but this approach would also extend to the validation loop evaluation during training with lightning.

Here, we can see that as well as the average validation loss across all batches (in this case only one), COCO metrics are also returned.

Using hooks for inference

We can also use the predict function directly on the processed images returned from our data loader. Let’s now unpack the batch to just get the images, as we don’t need the labels for inference.

Thanks to the “typedispatch” decorator, we can use the same predict function signature on these tensors.

It is important to note at this point that the images given by the DataLoader have already been transformed and resized to 512 × 512. Therefore, the predicted bounding boxes are relative to a 512 × 512 image, and to visualize these predictions on the original image, we must first resize that image to the same size.

As we can see, after resizing, the bounding box is in the correct position!

Conclusion

I hope you have found this article useful, and that it can act as a starting point as you experiment with EfficientDet in PyTorch!

All of the code required to replicate this post is available as a GitHub gist here. While gists are used as code snippets throughout this article, this is primarily for aesthetic reasons, and these snippets may not work as intended if copied directly. For working implementations, please refer to that gist.

Chris Hughes is on LinkedIn.

Principal Machine Learning Engineer/Scientist Manager at Microsoft. All opinions are my own.