Published in The Startup

Multi-Label Classification — Data Leakage, Model Testing, and Addressing Class Imbalance

Best practices for working with imbalanced datasets

A slice of the dataset with one-hot encoded labels.

Class Imbalance in Multi-Task Learning

Frequency distribution of the 14 pathologies in the X-ray dataset
  • The Hernia pathology has the greatest imbalance: only about 0.3% of training cases are positive.
  • Even Infiltration, the least imbalanced pathology, has only 16% of training cases labelled positive.

Common ways to address class imbalance include:
  • Undersampling the normal class
  • Oversampling the rare class
  • SMOTE (Synthetic Minority Oversampling Technique)
  • Weighted loss
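The first two options in the list can be sketched for a single binary label. This is a minimal NumPy illustration, not the author's code; the function name `random_resample` and the toy label column are hypothetical, and it assumes positives are the minority class. In a multi-label dataset, duplicating or dropping a row also duplicates or drops all of that row's other labels, which is one reason a weighted loss can be more convenient.

```python
import numpy as np

def random_resample(y, mode="over", seed=0):
    """Return row indices that balance one binary label column y (0/1).

    mode="over":  duplicate random positive (minority) rows until both classes match.
    mode="under": keep a random subset of negative (majority) rows as large as the positives.
    Assumes positives are the minority class.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    pos_idx = np.flatnonzero(y == 1)
    neg_idx = np.flatnonzero(y == 0)
    if mode == "over":
        # Sample extra positives with replacement to make up the difference
        extra = rng.choice(pos_idx, size=len(neg_idx) - len(pos_idx), replace=True)
        idx = np.concatenate([neg_idx, pos_idx, extra])
    elif mode == "under":
        # Keep only as many negatives as there are positives
        kept_neg = rng.choice(neg_idx, size=len(pos_idx), replace=False)
        idx = np.concatenate([kept_neg, pos_idx])
    else:
        raise ValueError("mode must be 'over' or 'under'")
    return rng.permutation(idx)  # shuffle so classes are interleaved

# Hypothetical label column: 2 positives out of 10 cases
y = np.array([0, 0, 1, 0, 0, 0, 1, 0, 0, 0])
over = random_resample(y, "over")
under = random_resample(y, "under")
print(y[over].mean(), len(over))    # 0.5 16
print(y[under].mean(), len(under))  # 0.5 4
```

Oversampling keeps every original case but repeats the rare ones; undersampling discards most of the normal cases, which is costly when positives are as rare as the 0.3% Hernia cases.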
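import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def frequency(labels):
    # Fraction of positive and negative cases for each class (column)
    N = labels.shape[0]
    positive_frequencies = np.sum(labels, axis=0) / N
    negative_frequencies = 1 - positive_frequencies
    return positive_frequencies, negative_frequencies

# `train` is the training DataFrame and `classes` the list of its label columns
pos_freq, neg_freq = frequency(train[classes].values)

# Long-format table (one row per class and label) for seaborn's hue argument
ratios = pd.concat([pd.DataFrame({"Class": classes, "Label": "Positive", "Value": pos_freq}),
                    pd.DataFrame({"Class": classes, "Label": "Negative", "Value": neg_freq})],
                   ignore_index=True)
f = sns.barplot(x="Class", y="Value", hue="Label", data=ratios)
plt.xticks(rotation=90)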
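The weighting used in the next lines can be written out per class. Setting each class's positive weight to its negative frequency, and its negative weight to its positive frequency, makes the two sides contribute the same product to the loss. The per-example loss line uses f(x) for the model's predicted probability for that class, which is notation assumed here rather than taken from the original:

```latex
\begin{aligned}
w_{\text{pos}} &= \text{freq}_{\text{neg}}, \qquad w_{\text{neg}} = \text{freq}_{\text{pos}} \\
\text{freq}_{\text{pos}} \cdot w_{\text{pos}} &= \text{freq}_{\text{pos}} \cdot \text{freq}_{\text{neg}} = \text{freq}_{\text{neg}} \cdot w_{\text{neg}} \\
\mathcal{L}(x, y) &= -\left[\, w_{\text{pos}}\, y \log f(x) + w_{\text{neg}}\, (1 - y) \log\big(1 - f(x)\big) \right]
\end{aligned}
```

For Hernia, with about 0.3% positives, both sides contribute 0.003 × 0.997 ≈ 0.003, instead of 0.003 versus 0.997 without weights.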
# Weighted loss: weight each positive case by the negative frequency
# and each negative case by the positive frequency
pos_weights = neg_freq
neg_weights = pos_freq
pos_contribution = pos_freq * pos_weights
neg_contribution = neg_freq * neg_weights

# With these weights, both contributions are equal for every class
ratios = pd.concat([pd.DataFrame({"Class": classes, "Label": "Positive", "Value": pos_contribution}),
                    pd.DataFrame({"Class": classes, "Label": "Negative", "Value": neg_contribution})],
                   ignore_index=True)
sns.barplot(x="Class", y="Value", hue="Label", data=ratios)
plt.xticks(rotation=90);
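The weights only take effect once they enter a loss function. Below is a minimal NumPy sketch of a weighted binary cross-entropy, assuming the model outputs one sigmoid probability per class; the function name `weighted_bce` and the toy labels are hypothetical, not the author's implementation. It averages over examples and sums over classes.

```python
import numpy as np

def weighted_bce(y_true, y_pred, pos_weights, neg_weights, eps=1e-7):
    """Weighted binary cross-entropy: mean over examples, summed over classes.

    y_true, y_pred: arrays of shape (N, C); pos_weights, neg_weights: shape (C,).
    """
    y_true = np.asarray(y_true, dtype=float)
    # Clip predictions so log() never sees exactly 0 or 1
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    loss = -(pos_weights * y_true * np.log(y_pred)
             + neg_weights * (1 - y_true) * np.log(1 - y_pred))
    return loss.mean(axis=0).sum()

# Hypothetical toy labels: 4 cases, 2 classes
y_true = np.array([[1, 0], [0, 0], [0, 1], [0, 1]])
pos_freq = y_true.mean(axis=0)        # [0.25, 0.5]
neg_freq = 1 - pos_freq               # [0.75, 0.5]
pos_weights, neg_weights = neg_freq, pos_freq

# Per-class weighted share of positive and negative cases: equal by construction
pos_part = (pos_weights * y_true).mean(axis=0)        # [0.1875, 0.25]
neg_part = (neg_weights * (1 - y_true)).mean(axis=0)  # [0.1875, 0.25]

# An uninformative model that always predicts 0.5
loss = weighted_bce(y_true, np.full(y_true.shape, 0.5), pos_weights, neg_weights)
print(loss)  # 0.875 * ln 2, about 0.6065
```

Because the contributions are balanced, a model cannot lower this loss just by predicting "negative" for every rare class.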

Data Leakage

Patient overlap
def leakage(df1, df2, patient_col):
    # Return True if any patient appears in both DataFrames, otherwise False
    df1_unique = set(df1[patient_col].unique().tolist())
    df2_unique = set(df2[patient_col].unique().tolist())
    patients_common = df1_unique.intersection(df2_unique)
    has_leakage = len(patients_common) >= 1
    return has_leakage

leakage(train, test, 'Patient ID')
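The check above detects overlap after the fact; the usual remedy is to split by patient rather than by image, so that all of a patient's images land in the same subset. A minimal NumPy sketch follows, with a hypothetical `patient_split` helper and made-up patient IDs; scikit-learn's `GroupShuffleSplit` provides similar group-aware splitting.

```python
import numpy as np

def patient_split(patient_ids, test_frac=0.2, seed=0):
    """Return (train_idx, test_idx) so that each patient is in exactly one subset."""
    patient_ids = np.asarray(patient_ids)
    rng = np.random.default_rng(seed)
    # Shuffle the unique patients, then assign whole patients to the test set
    unique = rng.permutation(np.unique(patient_ids))
    n_test = max(1, int(round(test_frac * len(unique))))
    is_test = np.isin(patient_ids, unique[:n_test])
    return np.flatnonzero(~is_test), np.flatnonzero(is_test)

# Hypothetical IDs: several X-ray images per patient
ids = np.array([1, 1, 2, 3, 3, 3, 4, 5, 5, 6])
train_idx, test_idx = patient_split(ids, test_frac=0.34)
print(set(ids[train_idx]) & set(ids[test_idx]))  # set()
```

Splitting on unique patients means the test fraction applies to patients, not images, so the image counts per subset can differ from `test_frac`.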

Model Testing



