Five mistakes to avoid when modeling with imbalanced datasets
And what to try instead
By Phillip Adkins, Michele Garner, and Dave Kooistra
Got 99 records — and Trues: only one
Welcome to the world of imbalanced datasets.
As data science professionals, we often encounter target variables whose positive class occurs so infrequently in the dataset that the positives might as well be outliers. Example datasets with this issue include credit card transactions labeled fraud versus non-fraud, cancer screenings, and mechanical failure records.
Unfortunately, much of the available advice on fitting models to imbalanced data doesn’t actually lead to better models.
To keep things simple, let’s focus on binary classification. Here are five common faux pas to avoid when working with imbalanced datasets, the fixes that lead to higher-quality models, and examples of how each mistake can affect a model’s outcome.
Mistake 1: Defaulting to dataset doctoring as a first approach
Let’s suppose there’s a project to design a binary classification model to identify emails as fraudulent (spam) or not fraudulent (ham). After doing some initial data exploration, findings show the dataset is imbalanced with 100 times more ham examples than spam examples.
It may be tempting to start the modeling process by resampling the dataset. Two common methods include downsampling (the removal of some of the majority class) and upsampling (the addition of duplicate records from the minority class).
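For concreteness, here is a rough sketch of what these two operations typically look like using scikit-learn’s resample utility (X_train and y_train are hypothetical NumPy arrays, with label 1 marking the rare spam class); the rest of this section explains why we recommend against reaching for either one by default.

```python
# A rough sketch of the two resampling operations discussed above.
# X_train and y_train are hypothetical arrays; label 1 marks the rare (spam) class.
import numpy as np
from sklearn.utils import resample

is_minority = (y_train == 1)
X_min, y_min = X_train[is_minority], y_train[is_minority]
X_maj, y_maj = X_train[~is_minority], y_train[~is_minority]

# Downsampling: drop majority-class rows until the classes are the same size.
X_maj_down, y_maj_down = resample(X_maj, y_maj, replace=False,
                                  n_samples=len(y_min), random_state=0)
X_down = np.vstack([X_maj_down, X_min])
y_down = np.concatenate([y_maj_down, y_min])

# Upsampling: duplicate minority-class rows until the classes are the same size.
X_min_up, y_min_up = resample(X_min, y_min, replace=True,
                              n_samples=len(y_maj), random_state=0)
X_up = np.vstack([X_maj, X_min_up])
y_up = np.concatenate([y_maj, y_min_up])
```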
Intuitively, resampling the dataset may seem like the right approach. However, a sample is meant to be representative of the full population it measures, and downsampling means losing data. That is especially problematic if the remaining dataset isn’t rich enough to represent the population, in which case the model may underperform once it’s applied. Upsampling is a safer choice but has its own issues, such as blowing up the size and memory consumption of the dataset.
More importantly, resampling and reweighting are common sources of errors that can decrease model accuracy (anywhere from a minor dip to a catastrophic plunge), yield misleading evaluation statistics, or make iterating on model building substantially less efficient.
Overall, we suggest that resampling is generally an anti-pattern that should be avoided unless there’s a good reason to apply it. Another technique related to resampling is “reweighting” or “sample weighting.” While not quite as prone to inducing errors as resampling, we also consider using reweighting as a default approach for modeling on imbalanced datasets an anti-pattern.
What to try instead
Instead of resampling, leave the data imbalanced. The techniques outlined below provide stronger alternatives to downsampling and upsampling.
Mistake 2: Relying on “predict” to predict
For models with scikit-learn–style APIs, calling a binary classifier’s “predict” function produces a discrete class prediction (e.g., a 0 or 1). What some people overlook is that “predict” is generally a heuristic that applies a fixed internal decision threshold of 0.5 to the model’s predicted probability. In other words, if the predicted probability of the positive class is 0.5 or greater, “predict” will yield a “True” or a 1; otherwise it will return a “False” or a 0.
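For most binary scikit-learn classifiers, that behavior can be reproduced by hand; here is a minimal sketch, assuming a hypothetical fitted model and feature matrix X:

```python
import numpy as np

# For most binary scikit-learn classifiers, "predict" is equivalent to
# thresholding the positive-class probability at a fixed 0.5.
proba = model.predict_proba(X)[:, 1]          # P(class == 1) for each row
manual_preds = (proba >= 0.5).astype(int)     # what "predict" effectively does
# np.array_equal(manual_preds, model.predict(X)) is generally True
```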
However, this heuristic often isn’t even close to the best you can do on a given dataset — and is often just plain bad on imbalanced data.
Relying on the “predict” function with a model fit on an imbalanced dataset may make the model appear to struggle purely because of the built-in default decision threshold. In other words, the model might be well fit, but “predict” isn’t pulling the best binary prediction out of it.
What to try instead
A better approach when training, testing, and validating models on imbalanced datasets is to tune a custom decision threshold that optimizes a cost function relevant to your specific use case.
Luckily, most libraries and packages come with a function that outputs a probability or decision score for each class, such as scikit-learn’s “predict_proba”. Instead of relying on the built-in behavior, these scores can be combined with a decision threshold set by the data practitioner to maximize predictive performance. The threshold gives a Machine Learning practitioner the freedom to sculpt and optimize the model’s output behavior.
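As a minimal sketch (assuming a fitted scikit-learn–style model and hypothetical arrays X_val, y_val, and X_test), tuning a threshold to maximize f1 might look like this:

```python
import numpy as np
from sklearn.metrics import f1_score

# Score the validation set once, then audition a grid of candidate thresholds.
val_proba = model.predict_proba(X_val)[:, 1]
candidate_thresholds = np.linspace(0.01, 0.99, 99)

# Keep the threshold that maximizes f1 on the validation set.
f1_by_threshold = [f1_score(y_val, val_proba >= t) for t in candidate_thresholds]
best_threshold = candidate_thresholds[int(np.argmax(f1_by_threshold))]

# At inference time, use the tuned threshold instead of "predict".
test_preds = (model.predict_proba(X_test)[:, 1] >= best_threshold).astype(int)
```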
For example, on the previously mentioned spam/ham project, you may be seeing an f1 score of 0.0 on your validation or test sets while using the “predict” function. Rather than assuming this is due to the model’s inability to focus on the rare spam items, you might find that using the model’s scores with a well-chosen threshold will increase your f1.
In the “Relying on ‘predict’ to predict” experiment in the linked experiments notebook, we demonstrate a case in which the f1-score we achieve using “predict” is 0.16, while we are able to obtain an f1-score of 0.43 with a custom decision threshold.
This result is not an exception; simple reasoning all but guarantees it. The performance of the model using “predict” is just one entry in the set of results available when auditioning a variety of decision thresholds. Because max(V) >= v whenever v is in V, tuning the decision threshold can only match or beat the default.
Mistake 3: Relying on algorithm defaults
The impulse to resample datasets may stem from the faulty reasoning that Machine Learning techniques will have an easier time learning from balanced data. In reality, most Machine Learning algorithms learn perfectly well from imbalanced data, as long as they are tuned properly.
Using the default hyperparameters for various algorithms may result in poor performance and require longer training times on imbalanced datasets.
What to try instead
Adjusting the right model hyperparameters can significantly improve performance on imbalanced data. Some algorithms even have hyperparameters that specifically address class balance, which then speeds up training time and improves performance.
For example, XGBoost trains faster and more accurately when “base_score” is set to the proportion of positive samples in the data. The default XGBoost hyperparameters cause the algorithm to spend many boosting iterations learning the bias, slowing learning and potentially causing the algorithm to take errant turns along the way. Setting the right constant initialization allows the algorithm to focus on learning variations in the data rather than approximating a constant.
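A minimal sketch of this initialization, assuming the xgboost package’s scikit-learn wrapper and hypothetical training arrays X_train and y_train:

```python
import numpy as np
from xgboost import XGBClassifier

# Start the booster's constant prediction at the observed positive rate instead
# of the library default, so boosting iterations are spent learning structure
# in the data rather than correcting a constant offset.
positive_rate = float(np.mean(y_train))       # roughly 0.01 for a 99:1 dataset
model = XGBClassifier(base_score=positive_rate)
model.fit(X_train, y_train)
```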
We demonstrate this in the linked experiments notebook. Using a dataset with a class balance ratio of 99 to 1, we compare a default XGBoost base_score to our recommended XGBoost initialization.
As the log loss and f1_score plots in the notebook show (both plotted as a function of the number of boosting iterations), not only does the properly initialized model converge much more quickly, but the final log loss and f1_score it converges to are also substantially better. While there may be no mathematical guarantee that the solution XGBoost converges on with this initialization will be better, we have seen it play out that way every time in practice.
Mistake 4: Resampling or reweighting to control output class statistics
Another common scenario encountered when modeling on imbalanced data is that the relative balance of predictions output by the model for each class isn’t desirable. Often, data practitioners will resort to resampling or reweighting the training set to coax the model to predict more or less of the minority class.
Resampling or reweighting isn’t necessary to accomplish this. It only increases the number of experiment iterations, raises the chances of making a mistake in your evaluation statistics, and can throw off your model’s calibration.
What to try instead
Similar to what’s been suggested previously, sculpt your output class balance using a custom decision threshold applied to your model’s decision function.
This is a simple, direct way to control the exact ratio of positives to negatives output by your model, and it is much safer and more computationally efficient than repeatedly retraining your model with different resampling or reweighting parameters.
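For example, to force the model to flag a specific fraction of cases as positive, the threshold can simply be set at the corresponding quantile of the model’s scores; here is a sketch assuming a fitted scikit-learn–style model, a hypothetical X_val, and an assumed target rate:

```python
import numpy as np

# Suppose we want the model to flag 5% of cases as positive (an assumed rate).
desired_positive_rate = 0.05

scores = model.predict_proba(X_val)[:, 1]
# Everything above the (1 - rate) quantile of the scores is labeled positive,
# so roughly desired_positive_rate of examples land in the positive class.
threshold = np.quantile(scores, 1.0 - desired_positive_rate)
predictions = (scores >= threshold).astype(int)
```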
In the experiment “Resampling or Reweighting to Control Output Statistics” in the linked notebook, we perform the following experiment:
- Obtain a particular value of “recall” using resampling.
- Show that we can find a decision threshold using a model trained on the raw data that attains that same recall.
- Repeat for a variety of recall values attained through resampling and show that we can match them every time.
- Compare precision values.
In this case, we find that the model trained on the raw (non-resampled) dataset is able to match the recall in each case with a precision that equals or exceeds what is obtained through resampling.
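A sketch of the threshold-matching step, assuming scikit-learn, a hypothetical fitted model, validation arrays X_val and y_val, and a target recall taken from a resampling run:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

target_recall = 0.80   # hypothetical recall achieved by the resampled model

val_proba = model.predict_proba(X_val)[:, 1]
precision, recall, thresholds = precision_recall_curve(y_val, val_proba)

# recall[:-1] lines up with thresholds and is non-increasing as the threshold
# grows, so the last feasible index gives the highest threshold (and therefore
# the best precision) that still meets the target recall.
feasible = np.where(recall[:-1] >= target_recall)[0]
matching_threshold = thresholds[feasible[-1]]
```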
In large part, we’re able not only to match but to beat the precision of the resampling strategy each time because that strategy is downsampling (the most common form of resampling). Downsampling throws away data, sometimes a lot of it on a highly imbalanced dataset. You would expect a model trained on more data to perform better in general, and that is what we’re seeing here.
You do not need to mutilate your dataset to get the recall you want! Just remember to use a tuned decision threshold.
Mistake 5: Tuning and evaluating model performance on resampled validation or test sets
Typically, datasets are split into three groups:
- A training set for training the model.
- A validation set to repeatedly validate the model’s performance on data not used in training.
- A test set to be used one time at the end to simulate the model’s performance in a production environment.
When resampling, it is important to resample only the training set and not the validation or test sets. This means that one must first split the data into training, validation, and test sets, and then resample the training set.
It is easy to accidentally resample the validation or test set by first resampling the entire dataset and then breaking it into training, validation, and test sets. That can be problematic for two reasons:
- The validation and test sets are meant to represent the data that the model will encounter when making predictions in production. These sets should match what would be encountered in production as closely as possible. Resampling the validation and test sets generally makes the classification problem appear easier than it actually is and results in misleadingly optimistic evaluations of model performance.
- In addition, care must always be taken to avoid leakage in your evaluation setups. If not done carefully, upsampling and other types of resampling with replacement can result in validation and test sets that contain duplicates of samples from the training set. This too will result in misleadingly optimistic model evaluation and drastically increases the potential for overfitting.
Both of these errors can exaggerate model quality statistics to the extent that a model unfit for deployment may appear to have excellent accuracy. Often the impact of this kind of error is not discovered until the model has been consumed downstream, sometimes for months or more, especially if the only performance statistics being monitored are based on resampled datasets.
What to try instead
If you do end up needing to resample for some reason, construct your validation and test sets before resampling — not after.
Keep clean separations between all dataset splits and be especially diligent with the test set: it’s the final gateway for testing a model’s viability in production and should match production expectations as closely as possible.
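A minimal sketch of that ordering, using scikit-learn’s train_test_split on hypothetical X and y arrays:

```python
from sklearn.model_selection import train_test_split

# Split first, stratifying on the label so each split keeps the natural
# (imbalanced) class proportions and still contains some minority examples.
X_train, X_holdout, y_train, y_holdout = train_test_split(
    X, y, test_size=0.4, stratify=y, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(
    X_holdout, y_holdout, test_size=0.5, stratify=y_holdout, random_state=0)

# Only now, if resampling is truly needed, resample X_train / y_train.
# X_val, y_val, X_test, and y_test stay untouched so they still reflect production.
```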
We illustrate the downside of validating on a resampled dataset, compared to our best-practice recommendation of doing no resampling, in the experiment “Evaluating on resampled datasets” in the linked notebook. In this experiment, we train a model, choose a decision threshold that maximizes f1 score on our validation set, and then apply it to our “production” dataset. We do this twice: once resampling the data before splitting out the validation set, and once with no resampling at all.
Here is a table of results (f1 scores, as reported in the linked notebook):

| Validation strategy | Validation f1_score | prod_f1_score |
| --- | --- | --- |
| Resampled validation set | 79% | 38% |
| Non-resampled validation set | 50% | 49% |
The first thing to note is that the evaluation metrics on the resampled validation set look much better than those same metrics as computed on the non-resampled validation set. For example, the f1 score on the resampled validation set is 79% compared to 50% on the non-resampled validation set. This is the prime motivating factor behind making this mistake. These statistics look much better — but appearances can be deceiving. These are spurious results.
Recall that these decision thresholds were selected to maximize the f1 score on the validation set. We can see the “true” performance of each model reflected in “prod_f1_score”. The “prod” dataset was split off before any data science was done and is reflective of the data the model would see after deployment.
Despite the fact that the model trained and evaluated on the resampled dataset appeared to achieve an f1 score of 79%, when evaluated on a representative test set we find that it’s substantially worse at 38%. The resampled validation set statistics were drastically overoptimistic. Imagine making a business decision on the basis of a statistic this misaligned with reality!
On the other hand, the model trained on the untouched dataset was estimated to achieve a 50% f1 score based on validation, which it essentially achieved when evaluated on a representative test set (49%).
Without resampling, not only was the estimated performance of the model using a validation set much closer to the reality we achieved in “prod”, but the model is also substantially better — an f1 score of 49% vs 38%.
Use cases for resampling or reweighting data
Although this article clearly discourages resampling imbalanced data, there may be certain instances where resampling or reweighting could be appropriate. Generally, there may be good reason to resample your data if you are encountering resource constraints (like time or memory) or if your optimizer is struggling.
For example, if there is a massive amount of data that is too cumbersome to process, it may be acceptable to downsample the majority class to reduce the training time. This may result in a drop in performance because there is less data to learn from, but if training time is a priority, a slight decrease in performance could be worth it. However, we’d still want to be sure we only downsample the training data, leaving the validation and test sets untouched. In addition, if calibration is important, we’d recommend reweighting the downsampled data to counteract resampling’s effect on the model’s calibration.
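A rough sketch of that combination: keep a random fraction of the majority class and up-weight the kept rows so the original class balance is preserved (X_train, y_train, model, and the keep fraction are all hypothetical):

```python
import numpy as np

# Keep a fraction of the majority (negative) class to cut training cost, then
# up-weight the kept negatives so the model still "sees" the original class
# balance and its probability estimates stay approximately calibrated.
keep_fraction = 0.1                      # assumed; tune to your resource budget
rng = np.random.default_rng(0)

is_negative = (y_train == 0)
keep = ~is_negative | (rng.random(len(y_train)) < keep_fraction)

X_small, y_small = X_train[keep], y_train[keep]
sample_weight = np.where(y_small == 0, 1.0 / keep_fraction, 1.0)

# Most scikit-learn-style estimators accept per-row weights in fit().
model.fit(X_small, y_small, sample_weight=sample_weight)
```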
An example of where upsampling may be appropriate is when it is possible to create new examples of the minority class through data augmentation. Instead of upsampling by exactly duplicating data points, data augmentation creates new artificial data points for the minority class; a popular technique for doing this is the Synthetic Minority Oversampling Technique (SMOTE). If you’re operating in a regime of data scarcity and could benefit from some extra samples, data augmentation can sometimes help. Data augmentation is not meant to fix class imbalance but rather to address data scarcity, where there are not enough samples from the minority class to learn from. Again, we’d recommend counteracting the augmentation’s effect on class balance with appropriate reweighting to ensure continued model calibration.
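A minimal sketch, assuming the separate imbalanced-learn package and hypothetical training arrays:

```python
from imblearn.over_sampling import SMOTE

# SMOTE synthesizes new minority-class rows by interpolating between existing
# minority examples and their nearest minority-class neighbors.
# Apply it to the training split only; never touch validation or test data.
smote = SMOTE(random_state=0)
X_train_aug, y_train_aug = smote.fit_resample(X_train, y_train)
```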
Reweighting also has its place. In particular, there are methods that may require a sample-specific weighting to reflect the noise in the target value. There may also be cost functions that are easiest to implement as reweighted versions of pre-existing cost functions. Reweighting can also be used to undo the miscalibration brought on by resampling.
Let’s wrap up with a summary that’s printable for your cork board: