Multi-Label Classification — Data Leakage, Model Testing, and Addressing Class Imbalance

Best practices for working with imbalanced datasets

Najia Gul
The Startup
6 min read · Dec 15, 2020


Multi-label classification falls under the umbrella of multi-task learning. It is crucial to point out that multi-label classification and multi-class classification are not the same problem.

Multi-label classification involves predicting zero or more class labels. Unlike normal classification tasks where class labels are mutually exclusive, multi-label classification requires specialized machine learning algorithms that support predicting multiple mutually non-exclusive classes or labels.
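As a small illustration of what "mutually non-exclusive" labels look like in practice, here is a sketch using scikit-learn's MultiLabelBinarizer (the disease names are taken from this dataset; the samples themselves are made up). Multi-label targets are multi-hot vectors in which any number of entries, including zero, can be 1:

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Each sample can carry zero or more labels (multi-label),
# unlike multi-class classification where exactly one label applies.
samples = [
    {"Cardiomegaly", "Effusion"},  # two diseases at once
    {"Hernia"},                    # a single disease
    set(),                         # a healthy patient: no labels at all
]

mlb = MultiLabelBinarizer(classes=["Cardiomegaly", "Effusion", "Hernia"])
multi_hot = mlb.fit_transform(samples)
print(multi_hot)
# [[1 1 0]
#  [0 0 1]
#  [0 0 0]]
```

This multi-hot encoding is exactly the one-hot-style label matrix used throughout the rest of the article.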

In this article, we’ll discuss common key challenges faced when dealing with multi-label classification problems. We’ll use a real-world dataset of medical chest x-rays, which contains 112,120 frontal-view x-ray images of patients annotated with 14 text-mined disease labels (where each image can have multiple labels), extracted from the associated radiological reports.

(All of the ideas presented in this article are directly inspired by the course ‘AI for Medical Diagnosis’.)

You can download the dataset from here. (I have done some preprocessing and one-hot encoded the labels. The detailed code can be found here.)

A slice of the dataset with one-hot encoded labels.

Class Imbalance in Multi-task learning

One of the key challenges when working with real-world datasets is the large imbalance among different classes. This problem is predominant in many scenarios, such as fraudulent transactions, medical diseases and spam filtering. For example, a dataset of a rare medical disease will contain mostly examples of healthy patients and very few examples of the disease itself.

Let’s plot the frequency of each of the labels to visualize this imbalance:

Frequency distribution of 14 pathological diseases using x-ray images

We can see from this plot that the prevalence of positive cases varies significantly across the different pathologies.

  • The Hernia pathology has the greatest imbalance with the proportion of positive training cases being about 0.3%.
  • But even the Infiltration pathology, which has the least amount of imbalance, has only 16% of the training cases labelled positive.

The problem of class imbalance can be addressed with many different approaches which are widely used such as:

  • Undersampling the normal class
  • Oversampling the rare class
  • SMOTE (Synthetic Minority Oversampling Technique)
  • Weighted loss
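As a minimal sketch of one of these approaches, here is oversampling of the rare class with replacement, using scikit-learn's resample on a toy frame (the column name and sizes are illustrative, not from the x-ray dataset):

```python
import pandas as pd
from sklearn.utils import resample

# Toy dataset: 8 negative (normal) examples, 2 positive (disease) examples.
df = pd.DataFrame({"label": [0] * 8 + [1] * 2})

majority = df[df["label"] == 0]
minority = df[df["label"] == 1]

# Oversample the rare class with replacement until it matches the majority.
minority_upsampled = resample(
    minority, replace=True, n_samples=len(majority), random_state=42
)
balanced = pd.concat([majority, minority_upsampled])

print(balanced["label"].value_counts())  # both classes now have 8 examples
```

Undersampling is the mirror image: resample the majority class down to the minority's size, at the cost of discarding data.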

For this problem, let’s use a weighted loss function to counteract the loss being dominated by the negative labels.

The cross-entropy loss for a training example xᵢ with a given disease being positive or negative is given as:

L(xᵢ) = −yᵢ log f(xᵢ) − (1 − yᵢ) log(1 − f(xᵢ))

where f(xᵢ) is the model’s predicted probability of the disease being present and yᵢ is the ground-truth label (1 for positive, 0 for negative).

The overall, average cross-entropy loss over the entire training set D of size N will then be:

L(D) = −(1/N) Σᵢ [ yᵢ log f(xᵢ) + (1 − yᵢ) log(1 − f(xᵢ)) ]

If we take a closer look at this loss function, we see that when most of the training examples carry negative labels, the loss is dominated by the negative class, i.e. the class with chest x-rays of normal patients. Most of the contribution to the loss comes from these normal examples, so the algorithm optimizes for them and gives relatively little weight to the examples with diseases.

If we modify the loss function by weighting the normal and disease classes differently, it will result in an equal contribution to the overall loss.

The current contribution of each class is calculated from its frequency in the training set:

freq_pos = (number of positive examples) / N

freq_neg = (number of negative examples) / N

Let’s write a function to calculate the frequency of each class.

import numpy as np

def frequency(labels):
    N = labels.shape[0]
    positive_frequencies = np.sum(labels, axis=0) / N
    negative_frequencies = 1 - positive_frequencies
    return positive_frequencies, negative_frequencies

Now, let’s plot a bar chart to visualize the two ratios of positive and negative labels next to each other for each of the pathologies:

pos_freq, neg_freq = frequency(train[classes])
ratios = pd.DataFrame({"Class": classes, "Label": "Positive", "Value": pos_freq})
negatives = pd.DataFrame([{"Class": classes[l], "Label": "Negative", "Value": v} for l, v in enumerate(neg_freq)])
ratios = pd.concat([ratios, negatives], ignore_index=True)
plt.xticks(rotation=90)
f = sns.barplot(x="Class", y="Value", hue="Label", data=ratios)

Now, by multiplying each example from each class by a class-specific weight factor, w_pos and w_neg, we can make the overall contribution of each class the same, i.e. we want w_pos × freq_pos = w_neg × freq_neg, which we can achieve by taking:

w_pos = freq_neg

w_neg = freq_pos

pos_weights = neg_freq
neg_weights = pos_freq
pos_contribution = pos_freq * pos_weights
neg_contribution = neg_freq * neg_weights

Visually, we want our graph to look like:

ratios = pd.DataFrame({"Class": classes, "Label": "Positive", "Value": pos_contribution})
negatives = pd.DataFrame([{"Class": classes[l], "Label": "Negative", "Value": v} for l, v in enumerate(neg_contribution)])
ratios = pd.concat([ratios, negatives], ignore_index=True)
plt.xticks(rotation=90)
sns.barplot(x="Class", y="Value", hue="Label", data=ratios);

After computing these weights, our final loss for each training case will be:

L(xᵢ) = −w_pos · yᵢ log f(xᵢ) − w_neg · (1 − yᵢ) log(1 − f(xᵢ))
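As a concrete sketch of this weighted loss (written in plain NumPy for clarity; this is not the course's exact implementation, and the toy numbers are made up), note how the positive and negative terms end up contributing equally once w_pos = freq_neg and w_neg = freq_pos:

```python
import numpy as np

def weighted_loss(y_true, y_pred, pos_weights, neg_weights, epsilon=1e-7):
    # Weighted binary cross-entropy, averaged over examples and classes.
    # y_true, y_pred have shape (num_examples, num_classes);
    # pos_weights and neg_weights are per-class weights (freq_neg and freq_pos).
    loss = -(
        pos_weights * y_true * np.log(y_pred + epsilon)
        + neg_weights * (1 - y_true) * np.log(1 - y_pred + epsilon)
    )
    return loss.mean()

# Toy check: 4 examples, 2 classes, each class positive in 1 of 4 examples.
y_true = np.array([[1, 0], [0, 1], [0, 0], [0, 0]])
pos_freq = y_true.mean(axis=0)       # [0.25, 0.25]
y_pred = np.full(y_true.shape, 0.5)  # uninformative predictions
print(weighted_loss(y_true, y_pred, 1 - pos_freq, pos_freq))
```

With these weights, the single positive example of each class (weighted by 0.75) contributes as much to the loss as the three negatives (each weighted by 0.25).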

Data Leakage

When splitting a dataset into train, validation and test sets, it is fundamentally important to make these sets independent of each other.

To illustrate this with an example, consider a patient with patient ID 1 who visits for an x-ray scan while wearing a necklace. The same patient visits the doctor at another time, wearing the same necklace. Now the dataset contains two images of the same patient wearing a necklace. When we split the dataset naively, one image may end up in the training set while the other ends up in the test set.

Patient overlap

The problem here is that the model may have actually memorized to output a certain label (such as ‘normal’) whenever it finds a necklace. Deep learning models can unintentionally memorize rare or unique aspects of the data, giving an ‘overly optimistic’ performance on the test set.

When we unknowingly pass information from our train dataset to our test dataset, we may get an overly optimistic performance which may not mirror the actual quality of our model.

One way to tackle this problem is to split the dataset by patients’ IDs. This way, all the x-rays belonging to a patient will be in the same set and there will be no patient overlap.
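A minimal sketch of such a patient-level split, using scikit-learn's GroupShuffleSplit on a toy frame (the column name 'Patient ID' matches this dataset; the rows and split sizes are illustrative):

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Toy frame: two images each for patients 1 and 2, one image for patient 3.
df = pd.DataFrame({
    "Patient ID": [1, 1, 2, 2, 3],
    "Image": ["a.png", "b.png", "c.png", "d.png", "e.png"],
})

# Grouping by patient ID keeps all images of a patient in the same split.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.4, random_state=0)
train_idx, test_idx = next(splitter.split(df, groups=df["Patient ID"]))
train, test = df.iloc[train_idx], df.iloc[test_idx]

# No patient appears on both sides of the split.
overlap = set(train["Patient ID"]) & set(test["Patient ID"])
print(overlap)  # set()
```

Note that the split fractions now apply to patients rather than images, so the image-level train/test ratio can drift slightly from the requested one.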

Let’s write a function to check for data leakage.

def leakage(df1, df2, patient_col):
    df1_unique = set(df1[patient_col].unique().tolist())
    df2_unique = set(df2[patient_col].unique().tolist())
    patients_common = df1_unique.intersection(df2_unique)
    # leakage is True if there is patient overlap, otherwise False
    leakage = len(patients_common) >= 1
    return leakage

After splitting the dataset into train and test, we can check if any patient ID occurs in both of these sets.

leakage(train, test, 'Patient ID')

This should output False.

Model Testing

How should one go about sampling a test set? We have seen so far that naively splitting a dataset into train and test sets isn’t that great of an idea. But even after we split our dataset by patient IDs, we may get a test set that doesn’t contain a good representation of our data. For instance, if we sample 10% of our dataset, our test set may not contain any examples of the classes with the lowest frequencies (such as the class Hernia). This means that we’ll never be able to test our model’s performance on those diseases.

One way to solve this problem when creating test sets is to sample such that at least X% of each minority class ends up in the test set (a good choice for X would be 50%). This ensures that we have a sufficient number of examples to test our model’s performance on all of the diseases.
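A minimal sketch of this idea on a toy frame (the 50% figure and the 'Hernia' label come from the text above; the row counts and test-set size are illustrative):

```python
import pandas as pd

# Toy frame: the rare class "Hernia" is positive in only 4 of 100 rows.
df = pd.DataFrame({"Hernia": [1] * 4 + [0] * 96})

# First reserve at least 50% of the rare positives for the test set,
# then fill the rest of the test set from the remaining rows.
pos = df[df["Hernia"] == 1]
test_pos = pos.sample(frac=0.5, random_state=0)          # 2 of the 4 positives
rest = df.drop(test_pos.index)
test_rest = rest.sample(n=10 - len(test_pos), random_state=0)
test = pd.concat([test_pos, test_rest])
train = df.drop(test.index)

print(int(test["Hernia"].sum()))  # at least 2 positives guaranteed in test
```

With multiple rare pathologies, the same reservation step would be repeated per class before filling the remainder of the test set.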


We have addressed a few of the most critical problems and fundamentally important practices we should follow when working on multi-label classification problems.

Here is the link to the notebook

Thanks for reading!