Balancing Act: A Guide to Undersampling for Improved Machine Learning Performance

Nirajan Acharya
3 min read · Mar 6, 2024


Under-sampling is a technique used to address the issue of class imbalance in datasets. Class imbalance occurs when one class (the minority class) is significantly underrepresented compared to the other class(es) in the dataset. This can lead to biased models that perform poorly on the minority class. Under-sampling aims to balance the dataset by reducing the number of instances in the majority class to match the number of instances in the minority class.

Figure: An unbalanced dataset

The algorithm for under-sampling involves randomly selecting a subset of instances from the majority class such that the number of instances in the majority class equals the number of instances in the minority class. This can be done using various methods such as random under-sampling, where instances from the majority class are randomly removed until the desired balance is achieved.

Algorithm:

Step 1-> Start with a dataset that exhibits class imbalance, where one class (the minority class) has significantly fewer instances than the other class (the majority class).

Step 2-> Determine the desired balance ratio between the minority and majority classes. This ratio is often 1:1, meaning the two classes will contain the same number of instances after undersampling.

Step 3-> Randomly select instances from the majority class without replacement until the number of instances in the majority class matches the number of instances in the minority class.

Step 4-> Verify that the resulting class counts satisfy the ratio chosen in Step 2; if not, repeat the random selection.

Step 5-> The resulting dataset is now balanced and can be used for model training and evaluation.
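
To make these steps concrete, here is a minimal sketch of random under-sampling in pandas. The DataFrame 'df' and its 'label' column are hypothetical names chosen for illustration; the majority class (label 0) is down-sampled to match the minority class (label 1) at a 1:1 ratio.

import pandas as pd

# Hypothetical toy data: 7 majority (label 0) rows vs 3 minority (label 1) rows
df = pd.DataFrame({
    'feature': range(10),
    'label': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
})

minority = df[df['label'] == 1]
majority = df[df['label'] == 0]

# Steps 2-3: draw, without replacement, as many majority rows as there are minority rows
majority_sampled = majority.sample(n=len(minority), random_state=42)

# Step 5: the balanced dataset, shuffled and ready for training
balanced = pd.concat([majority_sampled, minority]).sample(frac=1, random_state=42)
print(balanced['label'].value_counts())  # 3 rows per class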

Figure: Class distribution before and after under-sampling

Trade-Off

One of the key points to note about under-sampling is that it leads to a loss of information since it reduces the size of the dataset by discarding instances from the majority class. However, this trade-off is often necessary to prevent the model from being biased towards the majority class.

Importance of k-fold cross-validation with under-sampling:

While under-sampling helps to balance the dataset, it’s important to evaluate the performance of the model in a robust manner. One common technique used in machine learning for this purpose is k-fold cross-validation.

K-fold cross-validation involves splitting the dataset into k subsets (or folds) of equal size. The model is then trained k times, each time using k-1 folds for training and the remaining fold for validation, so that each fold serves exactly once as the validation data. The performance metrics are then averaged across all k folds to obtain a more reliable estimate of the model’s performance.
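
As a quick illustration (a sketch assuming scikit-learn is installed; the feature matrix X, labels y, and the logistic-regression model are placeholders), 5-fold cross-validation takes only a few lines:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Placeholder data: 100 samples, 4 features, binary labels
rng = np.random.default_rng(42)
X = rng.normal(size=(100, 4))
y = rng.integers(0, 2, size=100)

# k = 5: train on 4 folds, validate on the held-out fold, five times over
scores = cross_val_score(LogisticRegression(), X, y, cv=5)
print(scores.mean(), scores.std())  # averaged performance estimate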

When using under-sampling, k-fold cross-validation is particularly important because it helps to ensure that the performance metrics are not biased by the specific subset of instances selected during under-sampling. By averaging the performance metrics across multiple folds, k-fold cross-validation provides a more accurate assessment of the model’s generalization ability on unseen data.
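
One detail worth noting: to keep the validation folds representative of the real class distribution, under-sampling is usually applied only to the training folds inside each iteration, not to the whole dataset beforehand. The sketch below assumes the imbalanced-learn package (imblearn) is available and uses placeholder data:

import numpy as np
from imblearn.under_sampling import RandomUnderSampler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold

# Placeholder imbalanced data: roughly 90% class 0, 10% class 1
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (rng.random(200) < 0.1).astype(int)

scores = []
for train_idx, val_idx in StratifiedKFold(n_splits=5).split(X, y):
    # Under-sample the training folds only; the validation fold keeps
    # its original, imbalanced distribution
    X_bal, y_bal = RandomUnderSampler(random_state=42).fit_resample(X[train_idx], y[train_idx])
    model = LogisticRegression().fit(X_bal, y_bal)
    scores.append(balanced_accuracy_score(y[val_idx], model.predict(X[val_idx])))

print(np.mean(scores))  # average across the k folds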

Example:

Let us suppose we have a DataFrame ‘dset1’ containing information about different types of plants, with some plant types far more abundant than others. We define a target count, a cap of 8,082 instances per class. The code iterates through each unique plant type: if a type has more instances than the target count, it randomly selects a subset equal to the target count; otherwise, all instances for that plant type are retained. Finally, the under-sampled subsets are concatenated into a single DataFrame, ‘undersampled_dset1’, which is saved as a CSV file for further analysis.

import pandas as pd

# 'dset1' is assumed to be an already-loaded DataFrame with a 'Plant Type' column
target_count = 8082  # cap on the number of instances per plant type
undersampled_data = []

for plant_type in dset1['Plant Type'].unique():
    plant_type_data = dset1[dset1['Plant Type'] == plant_type]
    if len(plant_type_data) > target_count:
        # Over-represented class: keep a random sample of target_count rows
        undersampled_data.append(plant_type_data.sample(n=target_count, random_state=42))
    else:
        # Under-represented class: keep all rows
        undersampled_data.append(plant_type_data)

# Combine the per-class subsets into one balanced DataFrame
undersampled_dset1 = pd.concat(undersampled_data)
undersampled_dset1.reset_index(drop=True, inplace=True)

undersampled_dset1.to_csv('undersampled_dset1.csv', index=False)
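
After running the snippet, a quick sanity check confirms that no plant type exceeds the cap:

print(undersampled_dset1['Plant Type'].value_counts())  # every count should be <= 8082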
