Machine Learning for Unbalanced Datasets using Neural Networks

Can neural networks be used for binary classification in the case of unbalanced datasets?

Michael Kareev
Analytics Vidhya
10 min read · Sep 19, 2019


There are a few ways to address unbalanced datasets: from the built-in class_weight parameter in logistic regression and other sklearn estimators, to manual oversampling and SMOTE. We will look at whether neural networks can serve as a reliable out-of-the-box solution and which parameters can be tweaked to achieve better performance.

Code is available on GitHub.

We’ll use the Framingham Heart Study dataset from Kaggle for this exercise. It presents a binary classification problem in which we need to predict the value of the variable “TenYearCHD” (zero or one), which indicates whether a patient will develop coronary heart disease within ten years. The majority (~85%) of the patients don’t develop the condition, so it’s exactly the kind of situation we’re interested in exploring.

The target variable is imbalanced

The dataset requires some cleansing that is outside the scope of this article and is discussed extensively here and here. For now, I’ll just put the required code below:
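A minimal cleansing sketch, assuming the file name framingham.csv from the Kaggle download (the exact cleansing steps are covered in the discussions linked above; here I simply drop rows with missing values):

```python
import pandas as pd

# Load the Kaggle CSV (the file name is an assumption)
df = pd.read_csv('framingham.csv')

# Drop the relatively few rows with missing values
df = df.dropna()

# Separate the 15 features from the target
X = df.drop('TenYearCHD', axis=1)
y = df['TenYearCHD']
```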

The next step is to create train and test splits:
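A sketch, assuming a standard 80/20 stratified split; the scaling step is also my assumption, but networks generally converge better on standardized inputs:

```python
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

# Standardize the features (fit on train only to avoid leakage)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
```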

Moving to the network itself:

We will start with a basic Sequential model with three layers:
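A sketch of the architecture described below; the layer sizes and activations follow the text:

```python
from keras.models import Sequential
from keras.layers import Dense

classifier = Sequential()
classifier.add(Dense(units=8, activation='relu', input_dim=15))  # first hidden layer
classifier.add(Dense(units=8, activation='relu'))                # second hidden layer
classifier.add(Dense(units=1, activation='sigmoid'))             # output: P(TenYearCHD = 1)
```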

The input data consists of vectors, and the labels are scalars. I’m choosing a fully connected (Dense) layer with a relu activation. The units parameter is the number of hidden units in this layer; to start with something, we’ll use 8. input_dim tells the network the shape of its input, and 15 is the number of features. You can easily check it for yourself:
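```python
X_train.shape  # (n_samples, 15): the second dimension is the number of features
```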

Keras also allows you to pass input_shape instead, which takes a tuple describing your data. In our scenario, I could have also used input_shape=(15,).

The second layer is similar to the first one. The final layer uses a sigmoid function because I want to get probability scores between 0 and 1 (that a given patient will have a heart condition). Later on, you will be able to round the probabilities to zeroes or ones depending on the desired threshold.

The next step is to compile the network, i.e., to configure the learning process; this creates the Python object that builds the NN. Keras supports various optimizers, and they can be further tuned; we’ll start with Adam. The loss function will be binary_crossentropy, which is suited to binary classification tasks. Finally, you can track various metrics by passing a list to the metrics argument:
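A sketch of the compile call, mirroring the choices above:

```python
classifier.compile(optimizer='adam',
                   loss='binary_crossentropy',
                   metrics=['accuracy'])
```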

Then we will fit the model, make predictions, and check how accurate they are:
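A sketch of this step; the epoch count, batch size, and validation fraction are my assumptions:

```python
from sklearn.metrics import accuracy_score

history1 = classifier.fit(X_train, y_train,
                          epochs=100,
                          batch_size=32,
                          validation_split=0.2)

# Round the predicted probabilities at the default 0.5 threshold
y_pred = (classifier.predict(X_test) > 0.5).astype(int).ravel()
accuracy_score(y_test, y_pred)
```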

In the fit call, I’ve added validation_split, which holds out a fraction of the training data for validation. The object history1 (returned by classifier.fit) contains a dictionary with the values of the metrics chosen at the compile stage (one series for training and one for validation). It can be accessed like any other dictionary:
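For example:

```python
history_dict = history1.history
history_dict.keys()
# dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
# (older Keras versions name the accuracy keys 'acc' and 'val_acc')
```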

Values of the dictionary

It’s often more convenient to explore the results when they’re plotted:
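A minimal plotting sketch; plot_metric is a helper I’m introducing here:

```python
import matplotlib.pyplot as plt

def plot_metric(history, metric):
    """Plot a training metric alongside its validation counterpart."""
    plt.plot(history.history[metric], label='train')
    plt.plot(history.history['val_' + metric], label='validation')
    plt.xlabel('epoch')
    plt.ylabel(metric)
    plt.legend()
    plt.show()

plot_metric(history1, 'accuracy')  # use 'acc' on older Keras versions
plot_metric(history1, 'loss')
```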

Accuracy for training and validation sets
Loss for training and validation sets

Here’s a fairly interesting observation: our very first, basic model already overfits! After roughly the 60th-70th epoch, accuracy on the validation set starts decreasing while the validation loss goes up: the network has begun memorizing the training data rather than learning from it. Maybe we don’t need so many epochs and should stop the fitting process a little earlier? Let’s find out:

I’m introducing an EarlyStopping callback that interrupts training once a target metric stops improving for a number of epochs controlled by patience. After that, we can plot the updated charts.
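A sketch, assuming we monitor the validation loss and wait 10 epochs before stopping (both choices are assumptions):

```python
from keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss',
                               patience=10,
                               restore_best_weights=True)

history2 = classifier.fit(X_train, y_train,
                          epochs=100,
                          batch_size=32,
                          validation_split=0.2,
                          callbacks=[early_stopping])
```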

Validation set can still do better

While early stopping saved us compute, the 85% accuracy baseline (the share of the majority class) hasn’t been reached yet.

Another well-known method to deal with overfitting is L1/L2 regularization. Let’s explore!
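A sketch with an L2 penalty on both hidden layers; the 0.01 factor is an assumption (regularizers.l1 and regularizers.l1_l2 are available as well):

```python
from keras import regularizers

classifier = Sequential()
classifier.add(Dense(units=8, activation='relu', input_dim=15,
                     kernel_regularizer=regularizers.l2(0.01)))
classifier.add(Dense(units=8, activation='relu',
                     kernel_regularizer=regularizers.l2(0.01)))
classifier.add(Dense(units=1, activation='sigmoid'))
classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
```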

The results are drastically different:

Accuracy on the validation set is improving

If you compare the old and the new charts:

The new model was trained for fewer epochs, which is why the orange curve is longer

So, we have already achieved a better accuracy rate than the original model and have also surpassed the 85% baseline. We can now predict on the test set:
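Same as before, rounding at the default 0.5 threshold:

```python
y_pred = (classifier.predict(X_test) > 0.5).astype(int).ravel()
accuracy_score(y_test, y_pred)
```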

You can experiment with the classification threshold; the best value usually depends on the nature of your problem. In our case, it’s probably better to falsely diagnose a disease and later find out it was a mistake than to overlook the condition altogether. That means the number of false negatives should ideally be low, which we can monitor with recall_score (TP/(TP+FN)):
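For example:

```python
from sklearn.metrics import recall_score

recall_score(y_test, y_pred)  # TP / (TP + FN)
```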

The existing model returns 96% as its recall score.

Overall, it seems that we were able to resolve the overfitting issue. If that weren’t enough, we could combine the L2 regularization with dropout:
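A sketch combining both techniques; the 0.25 dropout rate is an assumption:

```python
from keras.layers import Dropout

classifier = Sequential()
classifier.add(Dense(units=8, activation='relu', input_dim=15,
                     kernel_regularizer=regularizers.l2(0.01)))
classifier.add(Dropout(0.25))
classifier.add(Dense(units=8, activation='relu',
                     kernel_regularizer=regularizers.l2(0.01)))
classifier.add(Dropout(0.25))
classifier.add(Dense(units=1, activation='sigmoid'))
```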

After the overfitting is taken care of, we can work on improving the performance further. Let’s try tweaking the learning rate schedule. If you’ve ever used the SGD class, you might have seen such parameters as decay and lr. These are our optimization targets:

How Learning Rate can be controlled

This is how it can be implemented:
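A sketch using the Keras 2.x SGD signature; all the hyperparameter values here are assumptions, chosen as reasonable starting points:

```python
from keras.optimizers import SGD

epochs = 100
learning_rate = 0.1                   # start larger than usual
decay_rate = learning_rate / epochs   # time-based decay shrinks lr on every update
momentum = 0.8

sgd_lr = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)
classifier.compile(optimizer=sgd_lr, loss='binary_crossentropy', metrics=['accuracy'])
```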

As you can see, we manually set the number of epochs, the initial learning rate, the decay rate, and the momentum, and passed them into sgd_lr, which serves as the optimizer at the compile stage. It’s generally recommended to start with a larger learning rate and momentum than you would use in a normal scenario.

So far, everything we’ve done has been geared toward improving the model itself: first handling overfitting, then increasing accuracy. We haven’t yet tried any methods specific to imbalanced datasets. Let’s see whether they can help. One of the simplest things to try is class_weight: think of it as an analogue of oversampling, since the minority class gets a larger weight in the loss function.
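A sketch that weights each class inversely to its frequency; the 'balanced' heuristic is my choice, and you can just as well pass a hand-picked dictionary:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
weights = dict(enumerate(class_weights))  # roughly {0: ~0.6, 1: ~3.3} for an 85/15 split

history3 = classifier.fit(X_train, y_train,
                          epochs=100,
                          batch_size=32,
                          validation_split=0.2,
                          class_weight=weights)
```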

In summary, you can combine various approaches: say, dropout and a learning rate schedule, or early stopping, L2 regularization, and class_weight.

In addition, you might start with a smaller network (units = 4 in the first and second layers), change the optimizer from Adam to rmsprop, or, if you have enough computing power and patience, do a GridSearch on some of these parameters:
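A sketch with the scikit-learn wrapper; build_model and the grid values are illustrative:

```python
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV

def build_model(units=8, optimizer='adam'):
    model = Sequential()
    model.add(Dense(units=units, activation='relu', input_dim=15))
    model.add(Dense(units=units, activation='relu'))
    model.add(Dense(units=1, activation='sigmoid'))
    model.compile(optimizer=optimizer, loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

model = KerasClassifier(build_fn=build_model, verbose=0)
param_grid = {'units': [4, 8],
              'optimizer': ['adam', 'rmsprop'],
              'batch_size': [16, 32],
              'epochs': [50, 100]}

grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=3)
grid_result = grid.fit(X_train, y_train)
print(grid_result.best_score_, grid_result.best_params_)
```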

We have examined a few ways to better control your neural network when working with unbalanced datasets. We can achieve a 1–3% improvement just by tweaking the existing parameters, but moving beyond that requires extra work with your data (think SMOTE or upsampling).
