Multi-class classification with focal loss for imbalanced datasets

Chengwei Zhang
Published in The Startup · 4 min read · Dec 15, 2018
Focus on hard examples

Focal loss was proposed in 2017 for the dense object detection task. It enables training highly accurate dense object detectors even with an extreme imbalance between foreground and background classes, on the order of 1:1000. This tutorial will show you how to apply focal loss to train a multi-class classifier model on a highly imbalanced dataset.

Background

Let’s first take a look at other common treatments for imbalanced datasets, and then at how focal loss tackles the issue.

In multi-class classification, a balanced dataset has target labels that are evenly distributed. If one class has overwhelmingly more samples than another, the dataset is imbalanced. This imbalance causes two problems:

  • Training is inefficient as most samples are easy examples that contribute no useful learning signal;
  • The easy examples can overwhelm training and lead to degenerate models.

A common solution is to perform some form of hard negative mining, which samples hard examples during training, or to use more complex sampling/reweighting schemes.

For image classification specifically, data augmentation techniques are also available to create synthetic data for under-represented classes.

The focal loss is designed to address class imbalance by down-weighting inliers (easy examples) so that their contribution to the total loss is small even if their number is large. It focuses training on a sparse set of hard examples.

Apply focal loss to a fraud detection task

For demonstration, we will build a classifier for the fraud detection dataset on Kaggle, which has extreme class imbalance: 6,354,407 normal and 8,213 fraud cases in total, or roughly 774:1. With such a highly imbalanced dataset, the model can take the easy route and guess “normal” for every input, achieving an accuracy of 6,354,407/6,362,620 ≈ 99.87%. However, we want the model to detect the rare fraud cases.
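To make that arithmetic explicit (these are the figures quoted above):

```python
normal, fraud = 6_354_407, 8_213
print(f"imbalance ratio: {normal / fraud:.0f}:1")                     # ~774:1
print(f"'always normal' accuracy: {normal / (normal + fraud):.4%}")   # ~99.87%
```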

To show that focal loss is more effective than commonly applied techniques, let’s first set up a baseline model trained with class_weight, which tells the model to “pay more attention” to samples from the under-represented fraud class.

Baseline model
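The exact baseline code is in the GitHub repo linked at the end of this post; the following is only a minimal sketch of the idea. The feature subset, network size, and train/test split here are illustrative assumptions, not the post’s exact setup:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow import keras

# PaySim fraud-detection data from Kaggle (the dataset's default filename).
df = pd.read_csv("PS_20174392719_1491204439457_log.csv")
feature_cols = ["amount", "oldbalanceOrg", "newbalanceOrig",
                "oldbalanceDest", "newbalanceDest"]       # assumed feature subset
X = StandardScaler().fit_transform(df[feature_cols].values)
y = df["isFraud"].values                                   # 0 = normal, 1 = fraud

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42)

# Weight the rare fraud class by the inverse of its frequency so it
# "gets more attention" during training.
n_normal, n_fraud = np.bincount(y_train)
class_weight = {0: 1.0, 1: n_normal / n_fraud}

model = keras.Sequential([
    keras.layers.Dense(32, activation="relu", input_shape=(X_train.shape[1],)),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(2, activation="softmax"),  # two classes: normal / fraud
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(X_train, keras.utils.to_categorical(y_train, 2),
          epochs=3, batch_size=2048, class_weight=class_weight)
```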

The baseline model achieved an accuracy of 99.87%, hardly better than taking the “easy route” of guessing all normal.

We also plot the confusion matrix to describe the classifier’s performance on the held-out test set. You can see there are a total of 1140 + 480 = 1620 misclassified cases.

Confusion matrix — baseline model
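If you want to reproduce the matrix yourself, a minimal sketch with scikit-learn, reusing the model and test-split names from the baseline sketch above:

```python
from sklearn.metrics import confusion_matrix

y_pred = model.predict(X_test).argmax(axis=-1)   # predicted class per test sample
print(confusion_matrix(y_test, y_pred))          # rows = true class, cols = predicted class
```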

Now let’s apply focal loss to the same model. You can see how to define the focal loss as a custom loss function for Keras below.

focal loss model
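The exact implementation from the post is in the GitHub repo linked at the end; here is a minimal sketch of one common way to write a categorical focal loss in Keras, assuming the two-class softmax output of the baseline sketch above, with the γ and α defaults taken from the paper:

```python
from tensorflow import keras
from tensorflow.keras import backend as K

def categorical_focal_loss(gamma=2.0, alpha=0.25):
    """Categorical focal loss: FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)."""
    def focal_loss(y_true, y_pred):
        # Clip predictions to avoid log(0).
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # Standard cross-entropy term for each class.
        cross_entropy = -y_true * K.log(y_pred)
        # Down-weight easy examples, i.e. those whose predicted probability
        # for the true class is already high.
        weight = alpha * K.pow(1.0 - y_pred, gamma)
        return K.sum(weight * cross_entropy, axis=-1)
    return focal_loss

# Recompile the same model as before, swapping only the loss function.
model.compile(optimizer="adam",
              loss=categorical_focal_loss(gamma=2.0, alpha=0.25),
              metrics=["accuracy"])
```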

There are two adjustable parameters for focal loss; the formula after this list shows how they enter the loss.

  • The focusing parameter γ (gamma) smoothly adjusts the rate at which easy examples are down-weighted. When γ = 0, focal loss is equivalent to categorical cross-entropy, and as γ increases, the effect of the modulating factor increases (γ = 2 works best in the paper’s experiments).
  • The balancing parameter α (alpha) weights the loss per class; the α-balanced form yields slightly improved accuracy over the non-α-balanced form.
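For reference, the focal loss from the paper (Lin et al., 2017), written for the predicted probability of the true class p_t:

```latex
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```

With γ = 0 and α_t = 1 this is just the ordinary cross-entropy term −log(p_t); as p_t approaches 1 (an easy, well-classified example), the modulating factor (1 − p_t)^γ drives the loss toward zero.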

Now let’s compare the performance with the previous classifier.

Focal loss model:

  • Accuracy: 99.94%
  • Total misclassified test-set samples: 766 + 23 = 789, cutting the mistakes roughly in half.

Confusion matrix — focal loss model

Conclusion and further reading

In this quick tutorial, we introduced a new tool for handling highly imbalanced datasets: focal loss. A concrete example showed you how to apply focal loss to a classification model built with the Keras API.

You can find the full source code for this post on my GitHub.

For a detailed description of focal loss, you can read the paper, https://arxiv.org/abs/1708.02002.

Originally published at www.dlology.com.
