Challenges of Training Models on Medical Data
Techniques to tackle Class Imbalance, Multi-Task, and Dataset Size
Among the many problems faced when training algorithms on medical datasets, these three are the most common:
- Class Imbalance challenge
- Multi-Task challenge
- Dataset Size challenge
For each of these problems, I will share a few techniques to tackle it. So let’s go through them one by one!
Class Imbalance challenge
In the real world, we see many more healthy people than diseased people, and medical datasets reflect this: the healthy and diseased classes are not equally represented. This imbalance mirrors the prevalence, or real-world frequency, of the disease. The problem is not unique to medicine; in credit card fraud datasets, for example, you might see a hundred times as many normal examples as abnormal ones.
As a result, it is easy to be fooled into believing that the model performs very well when it really doesn’t. This can happen when simple metrics like accuracy_score are used. Accuracy is a poor metric for such datasets because the labels are heavily skewed: if slightly over 90% of the examples are normal, a neural network that always outputs normal scores slightly over 90% accuracy without ever detecting a single abnormal case.
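To see the illusion concretely, here is a minimal plain-Python sketch. The dataset is hypothetical (990 normal examples labelled 0 and 10 abnormal ones labelled 1, roughly the hundred-to-one ratio mentioned above), and the "model" simply ignores its input and always predicts normal.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that exactly match the labels."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)

# Hypothetical labels: 990 normal (0) and 10 abnormal (1) examples.
y_true = [0] * 990 + [1] * 10

# A trivial "model" that always predicts normal, whatever the input.
y_pred = [0] * len(y_true)

# Number of abnormal examples the model actually caught.
abnormal_caught = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)

print(f"accuracy = {accuracy(y_true, y_pred):.2f}")  # accuracy = 0.99
print(f"abnormal cases caught = {abnormal_caught}")  # abnormal cases caught = 0
```

The accuracy looks excellent even though the model is useless for finding disease.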
We can instead use more informative metrics such as Precision, Recall, and F1 score. Precision is the number of True Positives divided by the sum of True Positives and False Positives; it is a good metric when the cost of False Positives is high. Recall is the number of True Positives divided by the sum of True Positives and False Negatives; it is a good metric when the cost of False Negatives is high. This is the case for most models in the medical field, where missing a disease is usually worse than raising a false alarm. Often, however, we need to take both False Positives and False Negatives into account, and that is what the F1 score does. It is the harmonic mean of Precision and Recall, striking a balance between the two, and is given by the formula F1 = 2 * (Precision * Recall) / (Precision + Recall).
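The three formulas translate directly into code once the confusion-matrix counts are known. This is a plain-Python sketch with hypothetical labels; in practice, scikit-learn's `precision_score`, `recall_score`, and `f1_score` in `sklearn.metrics` compute the same quantities. Returning 0.0 when a denominator is zero is a convention assumed here to keep the example simple.

```python
def precision_recall_f1(y_true, y_pred, positive=1):
    """Compute Precision, Recall and F1 for the given positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)

    # Guard against division by zero (convention chosen for this sketch).
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical predictions: 3 TP, 1 FN, 2 FP, 4 TN.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]
print(precision_recall_f1(y_true, y_pred))  # roughly (0.6, 0.75, 0.667)

# The always-normal model from the accuracy example scores zero on every metric.
print(precision_recall_f1([0] * 990 + [1] * 10, [0] * 1000))  # (0.0, 0.0, 0.0)
```

Unlike accuracy, these metrics immediately expose a model that never predicts the positive class.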
Another popular technique for dealing with class imbalance is resampling: either removing examples from the majority class (undersampling) or adding examples to the minority class (oversampling) so that the two classes are better balanced. Each approach has its drawbacks, however: undersampling loses information, and oversampling can lead to overfitting. More sophisticated resampling techniques can mitigate these issues.
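The simplest variants are random undersampling and random oversampling. The sketch below implements both in plain Python on a hypothetical dataset of 95 healthy and 5 diseased examples; the function names are illustrative, not a library API. Random oversampling here just duplicates minority examples, which is exactly why it is prone to overfitting.

```python
import random
from collections import Counter

def _group_by_class(examples, labels):
    """Map each label to the list of examples carrying that label."""
    groups = {}
    for x, y in zip(examples, labels):
        groups.setdefault(y, []).append(x)
    return groups

def random_undersample(examples, labels, seed=0):
    """Randomly drop examples so every class shrinks to the size of the smallest class."""
    rng = random.Random(seed)
    groups = _group_by_class(examples, labels)
    target = min(len(xs) for xs in groups.values())
    balanced = [(x, y) for y, xs in groups.items() for x in rng.sample(xs, target)]
    rng.shuffle(balanced)
    return balanced

def random_oversample(examples, labels, seed=0):
    """Randomly duplicate examples so every class grows to the size of the largest class."""
    rng = random.Random(seed)
    groups = _group_by_class(examples, labels)
    target = max(len(xs) for xs in groups.values())
    balanced = []
    for y, xs in groups.items():
        extra = [rng.choice(xs) for _ in range(target - len(xs))]
        balanced.extend((x, y) for x in xs + extra)
    rng.shuffle(balanced)
    return balanced

# Hypothetical imbalanced data: 95 healthy (0) vs 5 diseased (1) examples.
examples = list(range(100))
labels = [0] * 95 + [1] * 5

under = random_undersample(examples, labels)
over = random_oversample(examples, labels)
print(sorted(Counter(y for _, y in under).items()))  # [(0, 5), (1, 5)]
print(sorted(Counter(y for _, y in over).items()))   # [(0, 95), (1, 95)]
```

Undersampling keeps only 10 of the 100 examples (information loss), while oversampling shows each diseased example to the model about 19 times (overfitting risk).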
Multi-Task challenge
In the real world, predicting just healthy or diseased is usually not enough. We often need to classify medical data into multiple classes or labels. For example, detecting Arrhythmia from cardiac rhythms is less useful than knowing which type of rhythm was actually detected: it may be Atrial Fibrillation, Supraventricular Tachycardia, Atrial Flutter, or another type.
In theory, we could train a separate neural network for each label we need to classify, but that quickly becomes impractical. It would be better to combine all these classifiers into a single deep neural network that returns multiple predictions.
To tackle this challenge we use either Multi-class or Multi-label classification, which differ slightly. In Multi-class classification the classes are mutually exclusive, whereas in Multi-label classification a sample can belong to several classes at once. In the medical field we generally use Multi-label classification, since a patient diagnosed with Atelectasis may also have Cardiomegaly. To do this, we pass the raw scores of the model’s last layer through a sigmoid activation function, which converts each score into a value between 0 and 1 independently of the other scores.
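The independence of sigmoid outputs is easy to check numerically. In this plain-Python sketch the label names (the third one, Effusion, is added for illustration), the raw scores, and the 0.5 decision threshold are all hypothetical.

```python
import math

def sigmoid(z):
    """Squash a raw score into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical label names and raw last-layer scores for one chest X-ray.
label_names = ["Atelectasis", "Cardiomegaly", "Effusion"]
scores = [2.0, 1.5, -3.0]

probs = [sigmoid(s) for s in scores]

# Each label gets its own yes/no decision (0.5 threshold assumed here).
predicted = [name for name, p in zip(label_names, probs) if p > 0.5]
print(predicted)  # ['Atelectasis', 'Cardiomegaly']

# Changing one score leaves the other probabilities untouched.
scores_changed = [2.0, 1.5, 4.0]
probs_changed = [sigmoid(s) for s in scores_changed]
print(probs_changed[:2] == probs[:2])  # True
```

Note that the probabilities do not sum to 1, which is exactly what allows one patient to receive several labels at once.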
For Multi-label classification, the loss function then becomes binary_crossentropy: because of the sigmoid activation, each label is treated as an independent Bernoulli distribution. If we need Multi-class classification instead, we replace the sigmoid activation with a softmax layer and set the loss function to categorical_crossentropy.
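The two losses can be written out by hand to show the difference. This plain-Python sketch mirrors the idea behind binary_crossentropy and categorical_crossentropy but is not the Keras implementation; averaging the binary loss over labels is an assumption of this sketch. It also skips numerical safeguards: very large scores make `math.log(1 - p)` fail, which is why library losses typically offer a `from_logits=True` option that works on raw scores directly.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(scores):
    """Turn scores into probabilities that sum to 1 (shifted by the max for stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def binary_crossentropy(y_true, scores):
    """Multi-label loss: one independent Bernoulli term per label, averaged over labels."""
    losses = []
    for y, z in zip(y_true, scores):
        p = sigmoid(z)
        losses.append(-(y * math.log(p) + (1 - y) * math.log(1 - p)))
    return sum(losses) / len(losses)

def categorical_crossentropy(y_true_onehot, scores):
    """Multi-class loss: negative log-probability of the single true class under softmax."""
    probs = softmax(scores)
    return -sum(y * math.log(p) for y, p in zip(y_true_onehot, probs))

# Multi-label target: the patient has labels 0 and 2 at the same time.
print(binary_crossentropy([1, 0, 1], [0.0, 0.0, 0.0]))       # ln 2, about 0.693
# Multi-class target: exactly one class (index 1) is correct.
print(categorical_crossentropy([0, 1, 0], [0.0, 0.0, 0.0]))  # ln 3, about 1.099
```

With all-zero scores every sigmoid output is 0.5 and every softmax output is 1/3, so the losses reduce to ln 2 and ln 3 respectively.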
Dataset Size challenge
A major challenge of working with medical data is the size of the datasets. Besides a good architecture, a large training set plays an important role in a model’s performance, and often there simply isn’t enough patient data available for a disease. A small dataset is a major contributor to both high bias and high variance, making the model harder to optimize and harder to generalize.
To address the difficulty of optimization, we use Transfer Learning: we reuse the features learned by the lower layers of a model previously trained on another dataset, so that those features do not have to be learned again from scratch. Thanks to this previous training, the lower layers act as a good feature extractor, so we only need to replace the final layers and fine-tune them on our dataset. This makes optimization faster and reduces the amount of data required to train new models.
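In practice one would usually load a network pre-trained on a large image dataset (for example, from `tf.keras.applications`) and set the base model's `trainable` attribute to `False`. To show the idea without downloading weights, the toy NumPy sketch below stands in for that workflow: both datasets are hypothetical and synthetic, the network has a single hidden layer without biases, and `train` and `accuracy` are illustrative helpers. It pre-trains the whole network on a large source task, then freezes the lower layer and trains only a freshly initialised final layer on 40 target examples.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train(X, y, W1, w2, epochs=300, lr=0.5, freeze_lower=False):
    """Full-batch gradient descent on mean binary cross-entropy for a
    one-hidden-layer network. W1 is the lower layer, w2 the final layer."""
    W1, w2 = W1.copy(), w2.copy()
    n = len(y)
    for _ in range(epochs):
        H = relu(X @ W1)                # lower-layer features, shape (n, hidden)
        p = sigmoid(H @ w2)             # predicted probabilities, shape (n,)
        g = (p - y) / n                 # gradient of mean BCE w.r.t. the logits
        if not freeze_lower:
            g_hidden = np.outer(g, w2) * (H > 0)  # backpropagate through the ReLU
            W1 -= lr * (X.T @ g_hidden)
        w2 -= lr * (H.T @ g)
    return W1, w2

def accuracy(X, y, W1, w2):
    return float(np.mean((sigmoid(relu(X @ W1) @ w2) > 0.5) == y))

rng = np.random.default_rng(0)
n_features, n_hidden = 10, 16

# Hypothetical large source task with plenty of labelled data.
X_src = rng.normal(size=(5000, n_features))
y_src = (X_src[:, 0] + X_src[:, 1] > 0).astype(float)

# Hypothetical small target task that relies on related features.
X_tgt = rng.normal(size=(40, n_features))
y_tgt = (2 * X_tgt[:, 0] + X_tgt[:, 1] > 0).astype(float)

W1_init = rng.normal(scale=0.3, size=(n_features, n_hidden))
w2_init = rng.normal(scale=0.3, size=n_hidden)

# Step 1: pre-train the whole network on the source task.
W1_pre, _ = train(X_src, y_src, W1_init, w2_init)

# Step 2: keep the pre-trained lower layer frozen, attach a fresh final
# layer, and train only that layer on the small target dataset.
w2_new = np.zeros(n_hidden)
W1_ft, w2_ft = train(X_tgt, y_tgt, W1_pre, w2_new, freeze_lower=True)

print("target accuracy:", accuracy(X_tgt, y_tgt, W1_ft, w2_ft))
```

Only the 16 weights of the new final layer are fitted to the 40 target examples; the lower layer is reused as a fixed feature extractor.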
To address the difficulty of generalization, we use Data Augmentation: instead of duplicating the same examples, we apply small random transformations to them before feeding them to the model. This makes the model invariant to insignificant changes, such as differences in size or illumination when working with images. Flipping images horizontally or vertically, changing their brightness or contrast, and rotating or zooming them within limits are all common augmentations. This technique helps avoid overfitting on small datasets.
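A few of these transformations can be sketched with NumPy on a hypothetical grayscale image with pixel values in [0, 1]. The `augment` function and its parameter ranges are illustrative assumptions; rotation and zoom are left out because they require interpolation, which libraries such as Keras's `ImageDataGenerator` (`rotation_range`, `zoom_range`) provide. Whether a transformation preserves the label depends on the task: a horizontal flip of a chest X-ray, for instance, moves the heart to the other side.

```python
import numpy as np

def augment(image, rng, max_brightness_shift=0.1, contrast_range=(0.9, 1.1)):
    """Return a randomly transformed copy of a 2-D grayscale image in [0, 1].

    Applies a random horizontal flip, a random vertical flip, a small
    brightness shift and a small contrast change. The parameter values
    are illustrative assumptions, not tuned defaults."""
    out = image.astype(float)  # work on a copy; the original stays untouched
    if rng.random() < 0.5:
        out = np.fliplr(out)   # horizontal flip
    if rng.random() < 0.5:
        out = np.flipud(out)   # vertical flip
    out = out + rng.uniform(-max_brightness_shift, max_brightness_shift)  # brightness
    mean = out.mean()
    out = (out - mean) * rng.uniform(*contrast_range) + mean  # contrast around the mean
    return np.clip(out, 0.0, 1.0)  # keep pixel values in the valid range

rng = np.random.default_rng(42)
image = rng.random((4, 4))  # hypothetical 4x4 "scan"

# Each epoch can see a slightly different version of the same example.
variants = [augment(image, rng) for _ in range(3)]
```

Because the transformations are drawn fresh on every call, the model rarely sees exactly the same input twice, which is what distinguishes augmentation from simply duplicating examples.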