
Tree-based Machine Learning Models for Handling Imbalanced Datasets

A quick and dirty introduction to 3 of them

Li Sisters
Published in The Startup · 5 min read · Apr 7, 2020


Recently, I have been working on a binary classification problem with an imbalanced dataset, where the ratio of the positive class to the negative class is around 1:4. Imbalanced classification problems are so commonplace that data enthusiasts will run into them sooner or later. In this post, I will be sharing three tree-based machine learning models that can help handle imbalanced datasets.

The dataset I am going to use to illustrate the effectiveness of these algorithms is the credit card fraud dataset from Kaggle. It is extremely imbalanced: out of 284,807 transactions, there are only 492 frauds. Following convention, we label fraud samples as the positive class and normal transactions as the negative class. The positive class accounts for 0.172% of all transactions. The input variables are all numerical, as they are the output of a Principal Component Analysis. The data is free from null values and quite clean.

The dataset is split into a training set and a test set in an 80%:20% ratio.
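For reference, the split could be done along these lines (a quick sketch; the file name and the use of stratified sampling are my assumptions, not details from the original post):

import pandas as pd
from sklearn.model_selection import train_test_split

# Load the Kaggle credit card fraud dataset (file name assumed)
df = pd.read_csv('creditcard.csv')
X = df.drop('Class', axis=1)
y = df['Class']   # 1 = fraud, 0 = normal transaction

# 80%:20% split; stratify=y keeps the 0.172% fraud ratio in both sets
# (any feature scaling/normalisation steps are omitted from this sketch)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)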

Let’s train a logistic regression classifier to get a first impression:

from sklearn.linear_model import LogisticRegression
logreg = LogisticRegression()
logreg.fit(X_train,y_train)

print_report is a helper function I wrote that prints the confusion matrix, the Matthews Correlation Coefficient (MCC) and the classification report. At first glance, the logistic regression classifier finds frauds harder to classify: its recall is around 25% lower than its precision on both the training and test sets.
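print_report itself is not reproduced in this post, but a minimal sketch of such a helper, built from standard scikit-learn metrics (the exact signature is my guess), might look like this:

from sklearn.metrics import confusion_matrix, matthews_corrcoef, classification_report

def print_report(model, X, y):
    """Print the confusion matrix, MCC and per-class precision/recall for a fitted model."""
    y_pred = model.predict(X)
    print(confusion_matrix(y, y_pred))
    print('MCC:', matthews_corrcoef(y, y_pred))
    print(classification_report(y, y_pred, digits=4))

# Compare performance on both splits
print_report(logreg, X_train, y_train)
print_report(logreg, X_test, y_test)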

Weighted Decision Tree

There are mainly two approaches to modelling imbalanced data: cost-sensitive learning and resampling. The weighted decision tree takes the first approach, assigning the minority class a higher cost (weight) and the majority class a lower one. The scikit-learn library provides a class_weight parameter that can be adjusted to make the decision tree 'weighted'. One can specify the weights explicitly, usually based on domain knowledge about the problem, or use the built-in heuristic class_weight='balanced', which sets the weight of each class inversely proportional to its frequency in the training data. In the context of the credit card fraud detection dataset, this means the ratio of the weights assigned to fraud (positive) and normal transactions (negative) would be (1–0.172%):0.172%.
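For intuition, the weights produced by the 'balanced' heuristic can be inspected with scikit-learn's utility function (a quick illustration only, not part of the model code):

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weights: n_samples / (n_classes * np.bincount(y))
weights = compute_class_weight(class_weight='balanced',
                               classes=np.array([0, 1]), y=y_train)
print(dict(zip([0, 1], weights)))
# The fraud class ends up weighted roughly (1 - 0.172%)/0.172% times more
# heavily than the normal class, matching the ratio quoted above.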

Now, let’s construct the model and check its performance on the normalised dataset:

from sklearn.tree import DecisionTreeClassifier
# class_weight='balanced' weights classes inversely to their frequencies
wdt = DecisionTreeClassifier(class_weight='balanced', random_state=42)
wdt.fit(X_train, y_train)
Classification performance of Weighted Decision Tree

Without pruning, the weighted decision tree overfits the training set. We shall leave it like this for the sake of simplicity; hyperparameter tuning is a topic for another post :P

One point worth highlighting: the weighted decision tree classifier has a more 'balanced' performance (pun intended!) across precision and recall. By assigning a much heavier weight to the positive class, recall improves by 10% compared to the logistic regression classifier, at the expense of precision.

Weighted Random Forest

Random forests are an ensemble learning method whose building blocks are decision trees. A weighted random forest is a variant of the random forest built from weighted decision trees: its final output is the class that receives the majority of the votes cast by the individual weighted decision trees.

Now, let’s see how it works on the credit card fraud detection problem:

from sklearn.ensemble import RandomForestClassifier
# 'balanced_subsample' recomputes the balanced class weights on each tree's bootstrap sample
wrf = RandomForestClassifier(class_weight='balanced_subsample', random_state=42)
wrf.fit(X_train, y_train)
Classification performance of Weighted Random Forest

Overall, the weighted random forest performs better than the weighted decision tree, especially at classifying majority class samples. The MCC is 0.8575, indicating that the predictions agree closely with the true labels. (Recall that MCC ranges from -1 to 1.)
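For reference, MCC can be computed directly with scikit-learn; for example, on the test set:

from sklearn.metrics import matthews_corrcoef

# MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
mcc = matthews_corrcoef(y_test, wrf.predict(X_test))
print(f'Weighted random forest MCC: {mcc:.4f}')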

Balanced Random Forest

Balanced random forest uses the resampling approach in an interesting way. When we think of resampling, we usually picture massaging the training set to make it more 'balanced' before training, either by adding minority class samples or by removing majority class samples. In a balanced random forest, however, the original training set is passed to the model as-is: as the algorithm progresses, each bootstrap sample it draws is balanced by randomly under-sampling the majority class. (Recall that bootstrapping is sampling with replacement, and each bootstrap sample is used to grow one decision tree in the forest.) The imbalanced-learn library has implemented the algorithm (hooray!).
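Conceptually, the balancing done for each tree boils down to something like the sketch below (an illustration only; BalancedRandomForestClassifier takes care of this internally, so you never have to do it yourself):

from sklearn.utils import resample
from imblearn.under_sampling import RandomUnderSampler

# 1. Draw a bootstrap sample of the training set (sampling with replacement)
X_boot, y_boot = resample(X_train, y_train, replace=True, random_state=42)

# 2. Randomly under-sample the majority class so both classes are equally
#    represented before the tree for this bootstrap is grown
X_bal, y_bal = RandomUnderSampler(random_state=42).fit_resample(X_boot, y_boot)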

Let’s see how it works:

from imblearn.ensemble import BalancedRandomForestClassifier
# Each tree is grown on a bootstrap sample that has been randomly under-sampled to balance the classes
brf = BalancedRandomForestClassifier(random_state=42)
brf.fit(X_train, y_train)
Classification performance of Balanced Random Forest

Wow! The balanced random forest does really badly at classifying the majority class samples (the normal transactions). The one advantage gained by sacrificing precision is that its recall score is the highest among the classifiers presented in this post. For extremely imbalanced datasets, each bootstrap sample contains few or even none of the minority class samples, so individual trees in the forest may be poor predictors themselves, let alone contribute a meaningful vote. And when a bootstrap sample is under-sampled, far too many normal transactions are discarded, causing a great loss of the information from which the model could learn to classify majority class samples correctly.
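To put rough, back-of-the-envelope numbers on this (my own estimate, assuming the 80%:20% split roughly preserves the class ratio):

# Rough estimate of how much data each balanced bootstrap throws away
n_train = int(284_807 * 0.8)    # ~227,845 training transactions
n_fraud = int(492 * 0.8)        # ~393 training frauds
n_normal_kept = n_fraud         # majority class down-sampled to match the minority
n_normal_dropped = n_train - n_fraud - n_normal_kept
print(n_normal_dropped)         # ~227,000 normal transactions discarded per tree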

Conclusion

For this particular problem, the weighted random forest outperforms the logistic regression classifier, the weighted decision tree and the balanced random forest in terms of the Matthews Correlation Coefficient. Variants of random forests tend not to overfit the training set as much as decision trees do, so they are preferred when the dataset at hand is huge. For extremely imbalanced datasets, the cost-sensitive approach might work better than the resampling approach.

Hopefully, this post inspires data enthusiasts to explore more algorithms tailored for imbalanced classification problems. That being said, it is definitely worth checking out the imbalanced-learn library!
