How to Ensure Image Dataset Quality for Image Classification?
Every Computer Vision project starts with a data collection strategy and the creation of a dataset. In recent years, teams have focused on model development without investing as much as needed in training data creation. Creating an image dataset was complicated and time-consuming, and it was often done by engineers or interns in a rather inefficient way. This pain paved the way for a generation of labeling-tool startups and open-source tools such as:
- Labelbox
- Supervise.ly
- CVAT
- V7 Labs
- RectLabel
- Picsellia — and of course, our MLOps platform contains some pretty cool labeling tools too ;)
Those times are gone. Labeling tools are becoming a commodity, and AI companies can now leverage a multitude of tools and services to create their datasets.
Unfortunately, many enterprises still struggle with the performance of their AI models.
Most of these problems come from the quality of the dataset itself. By quality, we mean:
- The quantity of data in the dataset
- The amount of mislabeled data in your dataset
- The relevancy of the images inside the dataset
In this blog post, we’ll go through these 3 points and how to monitor them in order to optimize the overall quality of your dataset for your use cases. We’ll use image classification as an example.
How to Know if You Have Enough Images in Your Dataset?
1 — Empirical Rules to Determine the Minimum Number of Images
It can be complicated to determine how many images your dataset needs for an image classification task. However, there are some good rules of thumb you can follow. According to one of them, around 1,000 examples per class is a decent amount to start with.
But rules like this are not really grounded in data science. Take the “rule of 1000” as an example: it can be completely wrong if you use transfer learning, because a model pre-trained on a large dataset usually needs far fewer examples per class to reach good accuracy.
Fortunately, there are more robust ways to determine whether you have enough data in your training set. The first one is the sample-size determination methodology from the systematic review by Balki et al., explained in a Keras.io code example:
https://keras.io/examples/keras_recipes/sample_size_estimate/
2 — Sample-Size Determination Methodology Explained
“A systematic review of Sample-Size Determination Methodologies (Balki et al.) provides examples of several sample-size determination methods. In this example, a balanced subsampling scheme is used to determine the optimal sample size for our model. This is done by selecting a random subsample consisting of Y number of images and training the model using the subsample. The model is then evaluated on an independent test set. This process is repeated N times for each subsample with replacement to allow for the construction of a mean and confidence interval for the observed performance.”
This method is simple to understand yet effective. In short, you train a model N times on each of several increasingly large subsets of your dataset (say 5%, 10%, 25% and 50%). For each subset size, record the mean accuracy and standard deviation, then fit an exponential curve to these points to predict how many images you need to reach a given accuracy target.
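To make the balanced subsampling scheme concrete, here is a minimal NumPy sketch. It assumes integer class labels and a hypothetical `train_and_evaluate` callback that you write yourself (it trains a fresh model on the given indices and scores it on a fixed, independent test set). The function names are illustrative and are not taken from the Keras example.

```python
import numpy as np


def balanced_subsample(labels, fraction, rng):
    """Indices of a random subsample holding `fraction` of the data, with equal counts per class."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    per_class = int(round(fraction * len(labels) / len(classes)))
    per_class = min(per_class, int(counts.min()))  # the smallest class caps a balanced draw
    picked = [
        # No repeated images inside one subsample; every call draws from the full set again.
        rng.choice(np.flatnonzero(labels == c), size=per_class, replace=False)
        for c in classes
    ]
    return np.concatenate(picked)


def sample_size_experiment(labels, fractions, n_repeats, train_and_evaluate, seed=0):
    """Train n_repeats models per fraction and return {fraction: [test accuracy per run]}.

    train_and_evaluate(train_indices) -> float is a hypothetical callback: train a new
    model on those images and evaluate it on the same independent test set every time.
    """
    rng = np.random.default_rng(seed)
    return {
        f: [train_and_evaluate(balanced_subsample(labels, f, rng)) for _ in range(n_repeats)]
        for f in fractions
    }
```

Drawing each repetition independently from the full set is what lets you compute a mean and a spread per subset size.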
To make it clearer, let’s visualize it with a simple example.
Let’s say we have a training set of 1000 images evenly distributed between cats and dogs.
1. Train 5 models on 5% of the set → 50 images, and record the accuracy of each of them
   Accuracies at 5%: [0.3, 0.33, 0.28, 0.35, 0.26] (one accuracy per model)
2. Do the same thing with 10% of the set → 100 images, and record the accuracies
3. Repeat this with 20%, 35% and 50% of your set
You should now have 5 lists of accuracies, one for each of the 5 training-subset sizes, each list holding the results of 5 training runs.
Next, calculate the mean accuracy and standard deviation of each list, and fit an exponential curve through these data points. You should get a learning curve that rises quickly for small subsets and then flattens out as the subset grows.
By extrapolating the fitted curve beyond your current dataset size, you should be able to determine whether you have enough images in your training set to reach your accuracy target, and roughly how many more you would need.
The full methodology and code are available in the Keras.io example linked above.
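The fitting step can be sketched as follows. Note one deliberate swap: instead of an exponential, this sketch fits a power law acc(n) = a·n^b, a common shape for learning curves, because it becomes a straight line in log-log space and can be fitted with a single `np.polyfit` call. The function names are hypothetical; use another curve family if it fits your measurements better.

```python
import numpy as np


def fit_learning_curve(sample_sizes, accuracies):
    """Fit acc(n) = a * n**b to the mean accuracy of each training-set size.

    `accuracies` holds one list of run accuracies per size. The standard deviations
    are returned for error bars only; they are not used in the fit.
    """
    sizes = np.asarray(sample_sizes, dtype=float)
    means = np.array([np.mean(runs) for runs in accuracies])
    stds = np.array([np.std(runs) for runs in accuracies])
    # A power law is a straight line in log-log space: log(acc) = log(a) + b * log(n)
    b, log_a = np.polyfit(np.log(sizes), np.log(means), deg=1)
    return float(np.exp(log_a)), float(b), means, stds


def images_needed(a, b, target_accuracy):
    """Invert acc = a * n**b to estimate the training-set size that reaches the target."""
    return (target_accuracy / a) ** (1.0 / b)
```

For the cat/dog example, `sample_sizes` would be `[50, 100, 200, 350, 500]` and `accuracies` the five lists of run accuracies. Treat the output of `images_needed` as a rough estimate: the further the target lies beyond your largest subset, the less reliable the extrapolation. A power law also keeps growing, so it will return a size even for an unreachable target, and a fitted `b` close to or below 0 means accuracy is barely improving with more data.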
How to Identify Mislabeled Data in my Dataset?
There are multiple ways to answer this:
- The first one is purely operational → What was the workflow used to label my images?
- The second one is more analytical → How to detect mislabeled data automatically?
1 — Build an Annotation Workflow Made for Top Quality
There are some principles that you need to set before creating a dataset:
1. In the face of ambiguity, refuse the temptation to guess
This is inspired by the Zen of Python. It means that you need to set very clear guidelines for your annotators. The worst-case scenario is that annotators make different decisions for the same ambiguous class.
2. Three is always better than one
Whenever possible, make sure that several people annotate the same images, then extract every image that received different labels to understand precisely why the annotators disagree on it. Always aim for a 100% consensus score: where humans disagree, your CNN is very likely to struggle as well.
3. Have a third person review the dataset
Human bias is a real thing. It’s always a good idea to have a third-party actor in your annotation workflow to take care of the review process.
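If each image is labeled by several annotators, extracting the disagreements takes only a few lines of Python. This sketch assumes the labels are collected in a dictionary keyed by image id; the function name and the definition of the consensus score (the share of images on which all annotators agree) are our own assumptions, not a standard metric from a specific labeling tool.

```python
from collections import Counter


def consensus_report(annotations):
    """`annotations` maps an image id to the list of labels given by each annotator.

    Returns the share of images on which all annotators agree, and for every
    disputed image the vote count per label, so the disagreement can be reviewed.
    """
    disputed = {
        image_id: Counter(labels)
        for image_id, labels in annotations.items()
        if len(set(labels)) > 1  # at least two annotators chose different labels
    }
    consensus = 1 - len(disputed) / len(annotations)
    return consensus, disputed
```

The disputed images are the ones to discuss with your annotators, and often the ones whose guidelines need clarifying.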
2 — Programmatically Identify Mislabeled Images in your Dataset
Label anomalies can mean multiple things, but the 2 main causes are mislabeled data and ambiguous classes. There are various methods to extract wrongly labeled or ambiguous data, but in this blog post we will only go in depth into one of them:
Labelfix, an implementation of “Identifying Mislabeled Instances in Classification Datasets” by Nicolas M. Müller and Karla Markert.
Labelfix explained
In this paper and implementation, the authors present a nonparametric end-to-end pipeline to find mislabeled instances in numerical, image, and natural language datasets. They evaluate their system quantitatively by adding a small amount of label noise to 29 datasets, and show that they find mislabeled instances with an average precision of more than 0.84 when reviewing their system’s top 1% of recommendations. They then apply their system to publicly available datasets and find mislabeled instances in CIFAR-100, Fashion-MNIST, and others.
In simple words, the labelfix method finds a given percentage (chosen by the user) of the images most likely to be mislabeled: you specify X, and labelfix returns the X% of images it considers most likely to be wrongly labeled.
The magic behind this implementation is quite intuitive and can be summarized in 4 steps.
- Train a classifier on your entire training set, without holding out any images for a test set
- Run inference on the whole training set with this trained model
- For every prediction, compute the inner product ⟨y_n, ŷ_n⟩, where y_n is the one-hot encoded vector of the assigned label and ŷ_n is the predicted probability vector. Since y_n is one-hot, this is simply the probability the model gives to the assigned label
- Sort these inner products in ascending order and extract the first X%. These images, whose assigned label receives the lowest predicted probability, are the most likely to be mislabeled
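Steps 3 and 4 boil down to a few lines of NumPy once you have the model’s predicted probabilities for the training set. This is a hypothetical sketch of the ranking logic only, not the labelfix repository’s API; it assumes integer labels and a probability matrix with one row per image.

```python
import numpy as np


def most_likely_mislabeled(labels, probs, fraction):
    """Return indices of the `fraction` of samples with the lowest <y_n, y_hat_n>.

    labels: integer class ids, shape (N,).
    probs:  predicted class probabilities for the same samples, shape (N, C), from a
            classifier trained on the whole training set (steps 1 and 2).
    """
    labels = np.asarray(labels)
    probs = np.asarray(probs, dtype=float)
    one_hot = np.eye(probs.shape[1])[labels]  # y_n for every sample
    scores = (one_hot * probs).sum(axis=1)     # <y_n, y_hat_n> for every sample
    n_flagged = max(1, int(round(fraction * len(labels))))
    # Ascending order: the lowest probability given to the assigned label comes first.
    return np.argsort(scores, kind="stable")[:n_flagged]
```

The returned indices are candidates for human review rather than labels to flip automatically.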
The paper also includes a benchmark of the detection performance the researchers achieved on multiple datasets.
You can find the research paper here: https://arxiv.org/pdf/1912.05283.pdf
And the GitHub repo here: https://github.com/mueller91/labelfix
Now that we have a method to detect mislabeled data, one last question remains: are my images relevant to my training set?
How to Assess Data Relevancy in my Training Set?
Having irrelevant data in your training set can seriously damage the overall accuracy of your models. Duplicates are a typical example: they over-weight some samples, which biases the model and makes it harder for it to generalize to unseen data. Duplicates can also cause strange behavior in your validation set: if an image appears in your dataset several times, copies of it may end up in both your training and validation sets, which makes validation scores look better than they really are.
How to Find Duplicate Images in Your Training Set
We will call this trick the “embedding similarity method”: it is very effective and easy to implement.
The method has 2 steps:
1. Compute the embedding for every image
Images store a lot of information in their pixel values. Comparing them pixel by pixel is expensive and might not give high-quality results, since two near-identical photos can differ in many pixels after a slight shift or resize. To get better results, we can use a pre-trained computer vision model like MobileNet to generate an embedding for each image.
An embedding is obtained by passing an image through a deep model to produce a vector of a few hundred to a few thousand values (1,280 for MobileNetV2 with global average pooling) that distills the information stored in millions of pixels.
You should choose a pre-trained model that is lightweight but still accurate (which is why MobileNet is a good choice), so you can extract the embeddings of all your images relatively cheaply.
2. Compute the cosine similarity between all the embeddings
The cosine similarity between two embedding vectors A and B is:

$$\cos(A, B) = \frac{A \cdot B}{\lVert A \rVert \, \lVert B \rVert}$$
Because it depends only on the angle between the two vectors and not on their magnitudes, it is a more robust way to compare embeddings than a simple Euclidean distance.
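You can use scikit-learn, which provides an implementation in `sklearn.metrics.pairwise.cosine_similarity`. The output is an N×N matrix, where N is the number of images in your training set, and a score of 1 means the two embeddings point in exactly the same direction. Cosine similarity ranges from -1 to 1 in general, but embeddings taken after a ReLU activation, such as MobileNet’s pooled features, are non-negative, so the scores fall between 0 and 1.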
You can then choose a threshold to identify the images that are too similar. A good rule of thumb is to flag pairs with a similarity score of 0.9 or above and keep only one image from each such pair.
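Both steps can be sketched in Python. The embedding part assumes TensorFlow is installed and uses `tf.keras.applications.MobileNetV2` with global average pooling; the similarity and filtering part only needs NumPy and computes the same matrix as `sklearn.metrics.pairwise.cosine_similarity`. All function names are hypothetical, and the 0.9 threshold is the rule of thumb from above.

```python
import numpy as np


def build_mobilenet_embedder(weights="imagenet"):
    """Return embed(images) -> (N, 1280) array, using MobileNetV2 without its classifier head."""
    import tensorflow as tf  # imported here so the NumPy helpers below work without TensorFlow

    model = tf.keras.applications.MobileNetV2(
        input_shape=(224, 224, 3), include_top=False, weights=weights, pooling="avg"
    )

    def embed(images, batch_size=32):
        # images: array of shape (N, H, W, 3) with pixel values in [0, 255]
        out = []
        for start in range(0, len(images), batch_size):
            batch = np.asarray(images[start:start + batch_size], dtype="float32")
            batch = tf.image.resize(batch, (224, 224))
            batch = tf.keras.applications.mobilenet_v2.preprocess_input(batch)
            out.append(model(batch, training=False).numpy())
        return np.concatenate(out, axis=0)

    return embed


def cosine_similarity_matrix(embeddings):
    """N x N matrix of cos(A, B) = A.B / (|A| |B|) for every pair of rows."""
    x = np.asarray(embeddings, dtype=float)
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    x = x / np.clip(norms, 1e-12, None)  # guard against all-zero vectors
    return x @ x.T


def find_near_duplicates(embeddings, threshold=0.9):
    """Return (i, j, score) pairs with i < j and score >= threshold, plus indices to drop."""
    sim = cosine_similarity_matrix(embeddings)
    i_idx, j_idx = np.triu_indices(len(sim), k=1)  # each pair once, no self-pairs
    mask = sim[i_idx, j_idx] >= threshold
    pairs = [(int(i), int(j), float(sim[i, j])) for i, j in zip(i_idx[mask], j_idx[mask])]
    to_drop = set()
    for i, j, _ in pairs:
        # Greedy: keep the earlier image of each pair unless it was already dropped.
        if i not in to_drop:
            to_drop.add(j)
    return pairs, to_drop
```

For example, `embed = build_mobilenet_embedder()` followed by `pairs, to_drop = find_near_duplicates(embed(images))` lists the suspicious pairs for manual review and the indices you could remove. The pairwise matrix grows as N², so for very large datasets you would need to process it in chunks or use an approximate nearest-neighbor index instead.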
Some takeaways
Nowadays, ML teams no longer have to spend loads of time creating quality image datasets for their image classification projects. The surge of open-source and commercial labeling tools has facilitated a great portion of computer vision projects across the globe.
However, many AI companies still struggle with AI model performance. That’s why in this article we covered three of the most common dataset quality problems, using image classification as a use case. We covered the following issues:
- The quantity of data in the dataset
- The amount of mislabeled data in your dataset
- The relevancy of the images inside the dataset
And more importantly, we went through how to monitor them to optimize your dataset quality in your projects.
If you’re interested in trying a solution to ensure the quality of your datasets, Picsellia offers a 14-day trial of our MLOps platform. It covers the whole MLOps life cycle, including advanced image processing and object detection features. Give it a try!