Creating Balanced Multi-Label Datasets for Model Training and Evaluation.

Photo of a slice of rainbow cake
Splitting multi-label data isn’t a piece of cake. 🍰 (Image used under license from Shutterstock)

Splitting multi-label data in a balanced manner is a non-trivial task which has some subtle complexities. In this blog post, I review several algorithm implementations and attempt to find the best solution for the data that I’ve been working with.

Overview:

One recurring problem for applied scientists on the GumGum Verity team is getting representative train and validation (holdout) splits of our training data. A production model that the Computer Vision team has worked on has ~100 classes which aren’t exclusive (multi-label) and which suffer from class imbalances despite efforts to add more support. An additional complexity is that many of the labels in this data have high rates of co-occurrence.

This blog post compares and contrasts five forms of multi-label data splitting, from a basic random baseline to four more sophisticated, class-aware stratified algorithms. These methods were found while searching for the best multi-label stratified cross validation solutions. All of the sophisticated methods leverage an “iterative stratification” algorithm from the 2011 paper “On the Stratification of Multi-label Data”¹ by Sechidis et al.

One call-out that I think is worth noting up front is that frequently I have heard people quip something to the effect that “if one has a large amount of samples and one randomly splits the data, then things will basically balance themselves out”. Certainly what one quantifies as “large amount of samples” would have a large effect, but my experiments for this project show pretty clearly that random sampling works very poorly on this distribution which has on the order of 100k samples.

This also applies to stratification methods that don’t explicitly maintain disjoint folds. In the source code for GitHub user trent-b’s³ implementation of “MultilabelStratifiedShuffleSplit”, it is noted:

Note: like the ShuffleSplit strategy, multilabel stratified random splits do not guarantee that all folds will be different, although this is still very likely for sizeable datasets.

However, again w.r.t. the example data for this blog post, it was found that re-running a non-disjoint algorithm to generate multiple holdout sets can result in very high overlap relative to a random baseline, which is elaborated upon later.

Some multi-label stratified splitting solutions rely solely on treating each combination of True labels as its own class, which turns the problem into a multiclass (one class per sample) problem for which the usual StratifiedKFold from sklearn is suitable. Because the example data for this blog post has a relatively large number of classes, the resulting number of combinations is high, and most combinations would be very sparse. Therefore, for our situation, it didn’t make sense to look at this kind of solution. (The IterativeStratification algorithm from skmultilearn.model_selection, reviewed later, uses a modification of these label set combinations, albeit in a non-exhaustive manner.) There are more details on this later, but if your problem has far fewer classes and/or combinations, then multiclass conversion may be worth considering.
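To make the sparsity concern concrete, here is a minimal sketch of the multiclass conversion on made-up toy data. The helper name labelsets_to_classes is hypothetical and this is not taken from any of the libraries reviewed; it only shows how quickly label combinations end up with fewer members than there are folds.

```python
from collections import Counter


def labelsets_to_classes(label_rows):
    """Map each distinct multi-hot row to an integer class id (hypothetical helper)."""
    ids = {}
    classes = []
    for row in label_rows:
        key = tuple(row)
        if key not in ids:
            ids[key] = len(ids)
        classes.append(ids[key])
    return classes


# Toy multi-hot labels: 6 samples, 3 labels (made-up data).
Y = [
    [1, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [0, 1, 1],
    [1, 0, 0],
    [1, 1, 1],
]
classes = labelsets_to_classes(Y)
counts = Counter(classes)
n_splits = 3

# A combination with fewer members than folds cannot appear in every fold.
too_sparse = [c for c, n in counts.items() if n < n_splits]
print(classes)     # [0, 0, 1, 2, 0, 3]
print(too_sparse)  # [1, 2, 3]
```

Even in this tiny example, three of the four combinations are too rare to be spread over three folds, which is the situation our data is in at a much larger scale.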

A quick refresher on multi-label classification:

Figure by author, images licensed from Shutterstock except dog image provided by Divyaa Ravichandran and used with permission.

General goals and requirements:

(These are specific to our business problem, and won’t necessarily apply to other situations.)

  1. Ideally the splitting algorithm would create disjoint folds of train and holdout data so we could test models on different folds before proceeding to evaluations on a final test set. If there was some slight overlap between holdout folds, it might be acceptable. (Seeking excellence, not perfection here.)
  2. In our case, significant overlap between folds is bad for situations like model rapid prototyping where we train on one test set and evaluate on another test set to try to quickly evaluate a model’s potential to solve a problem. If there was significant overlap between folds, that would lead to data leakage which would require other workarounds.
  3. The algorithm should be stratified so that class imbalances in the labels are accounted for, and validation sets should have proportional numbers of samples. In other words, if there are 10 folds and a low-representation label has 500 samples in our data, then ideally 50 samples would be in each validation fold.

Other considerations:

  1. Short runtime or efficiency is not very important to us as we standardize and store data splits for re-use and wouldn’t be regenerating folds frequently.
  2. There are roughly 100k samples in this data, with roughly 100 classes of which multiple classes can be True. (EDA on distributions, etc. below.) There are counterfactual classes in the data, some of which are the most frequently occurring classes. These counterfactual classes are present to help the model distinguish between a car and a police car, for example. That also means that if an image were a police car, then it could also be labeled as a car.
  3. Some classes have very high co-occurrence with other classes, and this also applies to some classes that have relatively few samples. (This can affect the results of some algorithms a lot.)

A (vague) description of the business problem:

In this particular case, the model is attempting to discern if an image is one or more of many types of threat or safe classes. As noted previously, we’ve added some safe counterfactual classes to help the model generalize for the real world. For example if we had a “mountain” class (safe) but also a “volcano” class (threat), then the image of Mount St. Helens erupting should have True for both classes.
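As a small illustration of what “True for both classes” means in practice, the sketch below encodes a set of labels as a multi-hot vector. The class list and the to_multi_hot helper are hypothetical stand-ins built from the examples in this post, not our production label schema.

```python
# Hypothetical class names taken from the examples in this post.
CLASSES = ["mountain", "volcano", "car", "police_car"]


def to_multi_hot(true_labels, classes=CLASSES):
    """Return a 0/1 vector with one position per class."""
    present = set(true_labels)
    return [1 if name in present else 0 for name in classes]


erupting_volcano = to_multi_hot(["mountain", "volcano"])
police_car = to_multi_hot(["car", "police_car"])
print(erupting_volcano)  # [1, 1, 0, 0]
print(police_car)        # [0, 0, 1, 1]
```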

The end result of this mixture of classes is that there are some classes that occur very frequently, and other classes that are relatively infrequent, and that many of these classes have some degree of co-occurrence. Another challenge is that some relatively lower support classes have higher degrees of co-occurrence.

Therefore, we needed a solution which resulted in holdout datasets that were still representative of the overall data and allowed us to evaluate model performance in lower representation classes with more confidence.

More context about the data being split:

While different algorithms and papers were considered in trying to solve this problem, evaluating the different algorithms on different datasets was not performed. Therefore the only absolute conclusions that could be drawn would specifically apply for this particular example data.

However, I’m providing some information about distributions and other metrics on the example data so that you can evaluate whether the observations and findings on this data could also apply to your own. Finally, in Sechidis et al.’s paper¹ (Table 2) they did perform comparisons across different datasets, which may be worth looking at for your problem.

Figure 1

As we can see from Figure 1, the counterfactual classes occur most frequently, usually because those classes could apply to many types of pictures. Conversely, on the right side, we see some classes with relatively low support, which tend to be low threat classes and are less of a business priority.

Label combinations could be converted to one-hot representations that “On the Stratification of Multi-label Data”¹ calls “labelsets”. The following plot shows the distribution of labelsets in the data that occur one or more times:

Figure 2

The main takeaways from Figure 2 are:

  1. There are 5,110 combinations, of which ~62% occur only once.
  2. For labelsets that occur once, the mean label count is ~4, meaning 4 classes have True for that labelset.
  3. For labelsets that occur 2 or more times, the mean label count is ~3, meaning that 3 classes have True for that labelset.
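For readers who want to compute the same kind of summary on their own data, here is a rough sketch using only the standard library. The function name labelset_summary is hypothetical, and averaging over distinct labelsets (rather than weighting by how often each occurs) is my assumption about how one might define the mean label count.

```python
from collections import Counter


def labelset_summary(label_rows):
    """Summarize how often each distinct multi-hot row (labelset) occurs."""
    counts = Counter(tuple(row) for row in label_rows)
    singles = [ls for ls, n in counts.items() if n == 1]
    repeats = [ls for ls, n in counts.items() if n > 1]

    def mean_labels(labelsets):
        # Mean number of True labels per distinct labelset.
        if not labelsets:
            return 0.0
        return sum(sum(ls) for ls in labelsets) / len(labelsets)

    return {
        "n_labelsets": len(counts),
        "frac_singletons": len(singles) / len(counts),
        "mean_labels_singletons": mean_labels(singles),
        "mean_labels_repeated": mean_labels(repeats),
    }


# Made-up toy data: 6 samples, 3 labels.
toy = [[1, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 1], [1, 0, 0], [1, 1, 1]]
summary = labelset_summary(toy)
print(summary["n_labelsets"])      # 4
print(summary["frac_singletons"])  # 0.75
```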

Finally, another confounding issue in the example data is that labels co-occur and some, especially some of the lower representation classes, co-occur very frequently. This affects the iterative stratification algorithm from Sechidis et al¹ because:

“The motivation for this greedy key point of the algorithm, is the following: if rare labels are not examined in priority, then they may be distributed in an undesired way, and this cannot be repaired subsequently.”

For rarely occurring labels that have high co-occurrence with each other, the above greedy algorithm can also “break” the stratification. The rare “unfair” holdout distributions seen in the box plots later can probably be explained by this issue, although it’s a relatively small issue in the grand scheme of things.
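To give a feel for the greedy “rarest label first” idea, below is a simplified sketch of iterative stratification in plain Python. It is my own reading of the algorithm described in the paper¹, assumes equal-sized folds, and is not the implementation used by either library reviewed here; the function name and tie-breaking details are assumptions.

```python
import random


def iterative_stratification(label_rows, k, seed=0):
    """Simplified sketch: assign each sample to one of k equal-sized folds."""
    rng = random.Random(seed)
    n = len(label_rows)
    n_labels = len(label_rows[0])
    # How many samples each fold still "wants", overall and per label.
    want_total = [n / k] * k
    label_totals = [sum(row[j] for row in label_rows) for j in range(n_labels)]
    want_label = [[label_totals[j] / k for j in range(n_labels)] for _ in range(k)]
    remaining = set(range(n))
    assignment = [None] * n

    def assign(i, fold):
        assignment[i] = fold
        remaining.discard(i)
        want_total[fold] -= 1
        for j, flag in enumerate(label_rows[i]):
            if flag:
                want_label[fold][j] -= 1

    while remaining:
        # Count remaining positives per label and pick the rarest non-empty label.
        left = [0] * n_labels
        for i in remaining:
            for j, flag in enumerate(label_rows[i]):
                left[j] += flag
        candidates = [j for j in range(n_labels) if left[j] > 0]
        if not candidates:
            break  # only samples without any True label remain
        rare = min(candidates, key=lambda j: left[j])
        for i in sorted(i for i in remaining if label_rows[i][rare]):
            # Prefer the fold that most wants this label, then the fold that
            # wants the most samples overall, then break ties at random.
            best = max(
                range(k),
                key=lambda f: (want_label[f][rare], want_total[f], rng.random()),
            )
            assign(i, best)

    # Samples with no True labels go to whichever fold still wants the most.
    for i in sorted(remaining):
        best = max(range(k), key=lambda f: (want_total[f], rng.random()))
        assign(i, best)
    return assignment


# Made-up toy data: label 0 on 6 samples, label 1 on 3 of those, 6 unlabeled.
rows = [[1, 1]] * 3 + [[1, 0]] * 3 + [[0, 0]] * 6
folds = iterative_stratification(rows, k=3)
per_fold = [
    [sum(rows[i][j] for i in range(len(rows)) if folds[i] == f) for j in range(2)]
    for f in range(3)
]
print(per_fold)  # [[2, 1], [2, 1], [2, 1]]
```

Because each assignment decrements the counters for every True label on that sample, placing a rare label also consumes quota for whatever it co-occurs with, which is one way highly co-occurring rare labels could end up unevenly distributed.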

The following Figure 3 is a close up of a heatmap in which the data is filtered to the subset where the labels on the left column are True. Then, for each subset, the other labels are filtered to see if they co-occur. (Combination depth is only 2 here, but that covers 75% of all occurring labelsets in the example data.)

In this case, the order of the filtering matters. For example, if we filtered on a (made up) animal class first and got 1000 samples, and the second filter, cat, reduced that to 100, then there would be 10% co-occurrence in the heatmap. However, if the cat label were filtered first, then when we applied the animal label we would expect the co-occurrence to be 100% if our annotations were perfect.
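The asymmetry is easy to reproduce with the made-up animal and cat numbers above. The following is a minimal sketch of one heatmap cell; cooccurrence_ratio is a hypothetical helper, not the code used to draw the figures.

```python
def cooccurrence_ratio(label_rows, given, other):
    """Fraction of samples with label `given` True that also have `other` True."""
    subset = [row for row in label_rows if row[given]]
    if not subset:
        return 0.0
    return sum(row[other] for row in subset) / len(subset)


ANIMAL, CAT = 0, 1
# 1000 animal samples, 100 of which are also cats (made-up numbers).
Y = [[1, 1]] * 100 + [[1, 0]] * 900

animal_then_cat = cooccurrence_ratio(Y, ANIMAL, CAT)
cat_then_animal = cooccurrence_ratio(Y, CAT, ANIMAL)
print(animal_then_cat)  # 0.1
print(cat_then_animal)  # 1.0
```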

Figure 3

Figure 3 is a close up heatmap of a subset of classes. The yellow line on the diagonal is where the label is compared with itself, so the ratio is 1.0. From this anchor point we can see that obfuscated class c_01 has a relatively high co-occurrence (~0.8 by the color) with cf_04 (counterfactual class 4). Now that it is hopefully clear what the heatmap is showing, the following is the heatmap for a larger subset of classes.

Figure 4

Figure 4 then allows us to see that there is structure to the co-occurrences. Vertical lines tend to be commonly occurring counterfactual labels. There are some boxlike structures around the yellow diagonal, and those are generally formed by threat categories which have many associated classes. For example, violence or medical would have many classes associated with them, and those classes would tend to co-occur.

In conclusion, in reviewing the data, there are many classes of which some classes occur less frequently and are imbalanced. These classes in particular pushed us to find better ways of splitting our data. It should also be noted that we also are attempting to target some of the more problematic classes for further annotation, but care must be taken to minimize the bias introduced by different data harvesting tactics.

Algorithms Reviewed:

First of all, trying to find good resources on stratified splitting of multi-label data isn’t very easy. As one can see from the following screenshot, it’s easy to find questions with large interest about the topic, and yet the topic doesn’t appear to be nearly as solved as multiclass stratification with Stratified KFold algorithms. Therefore, the answers are incomplete or only useful within certain constraints. Certainly one doesn’t expect random answers on StackOverflow to be the end of the search, of course. :)

Figure 6, Screenshot from stackoverflow collected Sept 26th, 2021

This particular question⁴ was posed 3 years ago as of this blog post, and has been viewed 19k times, yet there’s only a couple of answers which may or may not apply to someone’s situation.

For this problem, the following potential solutions were evaluated to see how well they worked on the example data. Besides the random folds, all other algorithms use some implementation of the “iterative stratification”¹ algorithm noted earlier.

Two libraries are used for their implementations: scikit-multilearn² and iterative_stratification³, which imports as iterstrat.ml_stratifiers. Both libraries have disjoint and non-disjoint implementations, and both variants from each library were evaluated in case the non-disjoint holdout folds turned out not to overlap much.

Table 1

Note: Alg 1 is simply a random baseline implemented with numpy by the author. The implementation creates indices for each sample, shuffles the indices, and then takes non-overlapping samples of size = (num_samples // k_folds) + 1, rounding up as a shortcut to ensure that there are no leftover samples.
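A rough equivalent of that baseline is sketched below. To keep the example dependency-free it uses the standard library random module as a stand-in for numpy, so it is an approximation of the idea and not the exact code used for the experiments; the function name random_folds is hypothetical.

```python
import random


def random_folds(num_samples, k_folds, seed=0):
    """Shuffle sample indices and cut them into non-overlapping chunks."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Round the chunk size up so that no samples are left over.
    size = (num_samples // k_folds) + 1
    return [indices[start:start + size] for start in range(0, num_samples, size)]


folds = random_folds(100, 7)
fold_sizes = [len(f) for f in folds]
print(fold_sizes)  # [15, 15, 15, 15, 15, 15, 10]
```

One caveat with the rounding shortcut is that the last fold is smaller than the others, and for very small inputs (for example 10 samples and 5 folds) it yields fewer chunks than requested. Neither matters much at ~100k samples and 7 folds.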

Evaluation of Algorithms:

Fold Variation from Ideal Splits:

The primary form of evaluation for this survey was running 100 simulations of each algorithm and comparing the final distribution of each fold of holdout data as compared to a theoretical perfect split. The comparison was measured in terms of percentage of expected samples in holdout for that class label.

If there are 1000 samples for class A, and there are 10 folds, ideally we would want 100 samples in each holdout fold. Therefore the ideal number of samples is subtracted from the number of samples in the fold, and the result is divided by the ideal number of samples. This ratio is then converted to a percentage.

# percentage deviation of a fold's class count from the ideal count
result = 100 * ((num_samples - num_ideal_samples) / num_ideal_samples)
# example for holdout higher than ideal
# 100 * ((300 - 200) / 200) => +50%
# example for holdout lower than ideal
# 100 * ((150 - 200) / 200) => -25%
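Applied to every class label of a single holdout fold, the same formula might look like the sketch below. The function names are hypothetical, and the sketch assumes every label has at least one True sample so that the ideal count is never zero.

```python
def pct_deviation_from_ideal(num_samples, num_ideal_samples):
    """Signed percentage by which a count differs from the ideal count."""
    return 100 * (num_samples - num_ideal_samples) / num_ideal_samples


def fold_deviations(label_rows, holdout_indices, k_folds):
    """Per-label percentage deviation of one holdout fold from a perfect 1/k split."""
    n_labels = len(label_rows[0])
    deviations = []
    for j in range(n_labels):
        total = sum(row[j] for row in label_rows)
        ideal = total / k_folds  # assumes total > 0 for every label
        in_fold = sum(label_rows[i][j] for i in holdout_indices)
        deviations.append(pct_deviation_from_ideal(in_fold, ideal))
    return deviations


# Made-up toy data: 10 samples of label 0 and 10 samples of label 1, 5 folds,
# so the ideal holdout fold has 2 samples of each label.
Y = [[1, 0]] * 10 + [[0, 1]] * 10
deviations = fold_deviations(Y, holdout_indices=[0, 1, 2, 10], k_folds=5)
print(deviations)  # [50.0, -50.0]
```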

The results, by class label (ie column), are recorded for the 100 experimental runs, and then plotted in a box plot for both train and holdout sets. This allows us to see the distribution of how classes performed in many simulations in the different folds.

Test Fold Collisions, How Much Do Folds Overlap?

For Algo #2 and #5, disjoint folds aren’t maintained. Therefore each algorithm was run repeatedly to simulate generating k folds, and the holdout sets were saved so that the intersection of holdout folds could be evaluated for each algorithm implementation. This allows us to evaluate whether the random process actually did produce mostly disjoint holdout folds.

Note: For the purposes of these simulations, all algorithms used 7 folds because of constraints of the smallest label support levels.

This repeated fold process was also repeated in a separate experiment to test how much overlap random sampling (variant of Alg 1) without replacement would create. This was done to again create a baseline to compare the other algorithms against, and also to help verify that the code was working as expected.
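The random overlap baseline can be approximated with a few lines of standard library code. This is a sketch of the idea and not the experiment code; the helper names and the 70,000 sample count are made up for illustration.

```python
import random
from itertools import combinations


def random_holdout(num_samples, k_folds, rng):
    """Draw one holdout set of size num_samples // k_folds without replacement."""
    size = num_samples // k_folds
    return set(rng.sample(range(num_samples), size))


def pairwise_overlap(holdouts):
    """Fraction of one holdout set that is shared with another, for every pair."""
    return [len(a & b) / len(a) for a, b in combinations(holdouts, 2)]


rng = random.Random(0)
# Re-run the split 7 times independently, as a non-disjoint algorithm would.
holdouts = [random_holdout(70_000, 7, rng) for _ in range(7)]
overlaps = pairwise_overlap(holdouts)
mean_overlap = sum(overlaps) / len(overlaps)
print(len(overlaps))           # 21 pairs of holdout sets
print(round(mean_overlap, 3))  # close to 1/7, i.e. roughly 0.14
```

With independent random draws, two holdout sets that each cover 1/k of the data are expected to share about 1/k of their samples, so roughly 14% for 7 folds. Truly disjoint folds would give 0.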

Results:

Variance of Train Folds from Ideal:

Figure 7

The random sampling results, figure 7, have relatively high variance even for the train sets. This does not bode well for the holdout sets.

Figure 8

iterative_train_test_split, figure 8, shows some minor variance even on the train splits. This is a bit surprising because the iterstrat.ml_stratifiers implementation does not show this much at all.

Figure 9

IterativeStratification from Alg 3, figure 9, shows even more variance. This is a disjoint implementation so there are more constraints than Alg 2 has. However, see Alg 4, figure 10, for a comparable algorithm and constraints but better results.

Figure 10

MultilabelStratifiedKFold, Figure 10, has very low variance for all of the simulations yet also outputs disjoint folds.

Figure 11

MultilabelStratifiedShuffleSplit, Figure 11, has the lowest variance of any train set, although it does not create disjoint folds so there are fewer constraints.

Variance of Holdout Folds from Ideal:

Figure 12

Figure 12 has random splits which resulted in terrible performance compared to other methods. The red circles on the plot indicate that the scale had to be increased for this particular set of results. All other box plots in this blog are ±25%, but this plot needed ±60% to show the full range.

Figure 13

iterative_train_test_split from skmultilearn.model_selection has the worst relative performance among the iterative stratification algorithm solutions. While having some classes at -20% isn’t necessarily a deal breaker, other solutions do better in less time. This is somewhat surprising because this algorithm does not create disjoint folds and so has fewer constraints than some of the competitors here.

Figure 14

Algo 3, in Figure 14, is a relative improvement upon Algo 2, which is a bit strange because it does create disjoint folds and so has more constraints than Algo 2. However, it still doesn’t compare to the iterstrat.ml_stratifiers implementations that follow.

Figure 15

Disjoint folds and extremely low variation from ideal: it doesn’t get any better than Algo 4 using iterstrat.ml_stratifiers in Figure 15.

Figure 16

While the variation from ideal is nearly zero, MultilabelStratifiedShuffleSplit, Figure 16, doesn’t create disjoint folds. If that isn’t a downside, then this is the best performer on the example data.

Fold Intersection Overlap Distributions:

Figure 17, Overlap is about as expected.

For the random baseline, the overlap of randomly sampled test folds is as expected. (Figure 17)

Figure 18

Overlap is very high for Algo 2, using iterative_train_test_split from skmultilearn.model_selection. (Figure 18) It appears that there may be an issue with scikit-multilearn’s implementation of iterative_train_test_split. Multiple experimental runs were performed and also the documentation was checked to ensure that there wasn’t a hard coded seed or something.

Initially I suspected that this may have come from having so many conditions on the random selection, but then the implementation from iterstrat.ml_stratifiers echoed the random baseline, and didn’t show this high rate of intersection.

Figure 19

The distribution in figure 19 is indicative of disjoint folds. (No overlap)

Figure 20

The distribution in figure 20 is also indicative of disjoint folds. (No overlap)

Figure 21

For Algo 5, the MultilabelStratifiedShuffleSplit algorithm doesn’t create disjoint folds, but the overlap is essentially the same as the baseline random simulations. With its best in class splits, it’s a top choice if you do not need disjoint folds.

Conclusion:

Five algorithms were reviewed on the example data. Besides the random baseline, all of the algorithms were different implementations of “iterative stratification”¹, yet the results were quite different.

For my purposes (again, trying to get reliable splits that are also disjoint), it appears that MultilabelStratifiedKFold from the iterstrat.ml_stratifiers library, referred to here as Alg 4, is the best choice for the example data. The ability of Alg 4 to create reliable splits quickly, while not appearing to have an issue with high co-occurrence but low representation classes, is extremely impressive. Hats off to Trent B and his contributors for writing such a great implementation.

As a side note, the iterstrat implementation was also more than 10x faster than the skmultilearn implementations. Speed wasn’t a criterion for judgement but it certainly helped in producing this blog post. That being said, the implementation’s vectorized numpy operations, while very fast, are also fairly opaque as to how exactly things are working, and I didn’t have time to dive deeper and try to understand what they’re doing differently than the skmultilearn implementation.

Unfortunately I didn’t have nearly as good luck with the skmultilearn implementations: the variation from ideal was much higher, and the high overlap for the non-disjoint iterative_train_test_split (Alg 2) holdout folds was concerning, especially since the performance on my example data wasn’t better either.

Verdict:

Table 2

References:

1. Sechidis, K., Tsoumakas, G., & Vlahavas, I. (2011, September). On the stratification of multi-label data. In Joint European Conference on Machine Learning and Knowledge Discovery in Databases (pp. 145-158). Springer, Berlin, Heidelberg.
2. Szymański, P., & Kajdanowicz, T. (2017). A scikit-based Python environment for performing multi-label classification. ArXiv e-prints. Retrieved from http://arxiv.org/abs/1702.01460 http://scikit.ml/
3. Trent-B (2018). iterative-stratification [Source Code]. https://github.com/trent-b/iterative-stratification
4. Data Science StackExchange. Retrieved September 25th, 2021 from https://datascience.stackexchange.com/questions/33076/how-can-i-perform-stratified-sampling-for-multi-label-multi-class-classification


Computer Vision Scientist with experience at Macys.com, Amazon.com, and currently GumGum.com