Domain adaptation with an adversarial algorithm for blood cell classification

Hematology is the study of blood, blood-forming tissues, and blood diseases, and accurate diagnosis is critical for the effective treatment of blood disorders. One of the main duties of hematologists is the classification of blood cells: doctors analyze blood smears of their patients and evaluate the content of pathological blood cells that might hint at leukemia, anemia, and other diseases. In practice this tedious task is more often than not performed manually, but it clearly lends itself to modern image recognition technology.

Traditionally, diagnostic and prognostic tools in hematology have been trained on relatively small and homogeneous datasets. However, these datasets may not accurately reflect the diversity and complexity of real-world patient populations, which can lead to lower accuracy and less effective treatment decisions. Moreover, images coming from different labs vary in sharpness, brightness, contrast, scale and other properties. Therefore, one aims to develop an algorithm that is agnostic to these secondary factors and can confidently discriminate cell images regardless of their origin.

This problem makes a great use case for domain adaptation. Domain adaptation techniques can improve the generalizability of diagnostic and prognostic tools by allowing them to be trained on larger and more diverse datasets, which reduces the error rate and improves the overall accuracy of these tools.

Domain adaptation techniques can also help address the problem of imbalanced datasets in hematology. Imbalanced datasets, where the number of samples in one class is significantly higher than in another, can cause diagnostic and prognostic tools to be biased towards the majority class. Domain adaptation techniques can help correct this bias and improve the performance of these tools on imbalanced datasets.

Dataset

We are thus facing an unsupervised domain adaptation problem. This stands in contrast to last year's VisDA challenge, where labels for the validation set were available. Therefore, we aim to build a model that simultaneously learns correlations in the source datasets while maintaining the ability to extrapolate to the target dataset.

Blood cell images from the source datasets (left and center) and target dataset (right).

We will refer to each dataset by the name of the first author who published the data. Below you can find a short description and links to the original papers:

  • Acevedo_20 dataset: the dataset (Acevedo et al., 2020) contains a total of 17,092 images of individual normal cells, acquired using the automatic analyzer CellaVision DM96 in the Core Laboratory at the Hospital Clinic of Barcelona. The images were obtained during the period 2015-2019 from blood smears collected from patients without infections, hematologic or oncologic diseases, and free of any pharmacologic treatment at the moment of their blood extraction. The images are in jpg format, their size is 360x363 pixels, they were obtained in the RGB color space, and they were annotated by expert clinical pathologists.
  • Matek_19 dataset: the Munich AML Morphology Dataset (Matek et al., 2019) contains 18,365 expert-labeled single-cell images taken from peripheral blood smears of 100 patients diagnosed with Acute Myeloid Leukemia at Munich University Hospital between 2014 and 2017, as well as 100 patients without signs of hematological malignancy. The images were obtained in the color space RGB and their size is 400x400 pixels.

We aim to achieve a high macro f1 score on a third dataset, called WBC.

  • WBC1 dataset (validation set): a small subpart of the WBC dataset. It is unlabeled and can be used for training, evaluation, and domain adaptation.
  • WBC2 dataset (test set): a second similar subpart of the WBC dataset.
Class distribution in the source datasets.

As can be seen, the class distribution in the datasets is uneven. It roughly corresponds to the actual proportions of white blood cells in blood smears. We have to take that into account, since we aim to maximize the macro f1 score, which gives equal weight to each class regardless of its cardinality.
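To illustrate why this matters, here is a toy example (scikit-learn, made-up labels) of a heavily imbalanced two-class problem and a classifier that always predicts the majority class:

from sklearn.metrics import f1_score

y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]   # the majority class dominates
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]   # a model that always predicts the majority class

print(f1_score(y_true, y_pred, average="micro"))   # 0.8 (looks deceptively good)
print(f1_score(y_true, y_pred, average="macro"))   # ~0.44 (the missed minority class drags the score down)

The micro-averaged score rewards the trivial majority-class predictor, while the macro-averaged score does not, which is why the class imbalance has to be handled explicitly.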

Preprocessing

The training images are preprocessed and augmented with the following albumentations pipeline:

CenterCrop(always_apply=False, p=1.0, height=345, width=345)
RandomCrop(always_apply=False, p=1.0, height=224, width=224)
Blur(always_apply=False, p=0.5, blur_limit=(3, 7))
RandomFog(always_apply=False, p=0.5, fog_coef_lower=0.3, fog_coef_upper=1, alpha_coef=0.08)
ColorJitter(always_apply=False, p=0.5, brightness=[0.7, 1.3], contrast=[0.5, 1.5], saturation=[0.5, 1.5], hue=[0.0, 0.0])
Flip(always_apply=False, p=0.5)
Rotate(always_apply=False, p=0.5, limit=(-180, 180), interpolation=1, border_mode=4, value=None, mask_value=None, rotate_method='largest_box', crop_border=False)
RandomScale(always_apply=False, p=0.5, interpolation=1, scale_limit=(-0.2, 0.2))
Resize(always_apply=False, p=1, height=224, width=224, interpolation=1)
Normalize(always_apply=False, p=1.0, mean=[0.8209, 0.7282, 0.8364], std=[0.1649, 0.2523, 0.0945], max_pixel_value=255.0)
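For reference, here is how the same pipeline could be written as an albumentations Compose. This is a sketch based on the printed configuration above; argument names may differ slightly between albumentations versions.

import albumentations as A

transform = A.Compose([
    A.CenterCrop(height=345, width=345, p=1.0),
    A.RandomCrop(height=224, width=224, p=1.0),
    A.Blur(blur_limit=(3, 7), p=0.5),
    A.RandomFog(fog_coef_lower=0.3, fog_coef_upper=1.0, alpha_coef=0.08, p=0.5),
    A.ColorJitter(brightness=(0.7, 1.3), contrast=(0.5, 1.5), saturation=(0.5, 1.5), hue=0.0, p=0.5),
    A.Flip(p=0.5),
    A.Rotate(limit=(-180, 180), p=0.5),
    A.RandomScale(scale_limit=0.2, p=0.5),
    A.Resize(height=224, width=224, p=1.0),
    A.Normalize(mean=(0.8209, 0.7282, 0.8364), std=(0.1649, 0.2523, 0.0945), max_pixel_value=255.0),
])

# usage: augmented = transform(image=image)["image"]   # image is an RGB numpy array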

Model description

Recall that in domain adaptation problems there are no labels for the target domain. The idea is that the distributions of the scoring functions on the source and target domains should not differ considerably. Thus, one introduces a measure of distance between the source distribution P and the target distribution Q (we denote the corresponding empirical sample distributions by P^ and Q^, respectively). This distance is then minimized in addition to the classification loss on the source domain.

The generalization bounds obtained in the paper (Zhang et al., 2019) allow us to reduce the problem of minimizing the error rate on the target domain to minimizing the sum of the empirical margin loss and the empirical margin disparity discrepancy (MDD), which we introduce in the following. The samples (x, y) are drawn from the common distribution D, and we introduce a cut-off function

Φ_ρ(t) = 1 for t ≤ 0,  1 − t/ρ for 0 ≤ t ≤ ρ,  0 for t ≥ ρ.

The margin and the margin loss are defined as

ρ_f(x, y) = ( f(x, y) − max_{y′ ≠ y} f(x, y′) ) / 2,

err^(ρ)_D(f) = E_{(x,y)~D} [ Φ_ρ(ρ_f(x, y)) ].

Thus, the margin loss favors confident predictions. Next, we introduce a measure of discrepancy between two distributions in terms of the margin. We write for the labeling function induced by the scorer f

h_f(x) = argmax_y f(x, y).

Then, for some hypothesis class F we define the margin disparity and the margin disparity discrepancy (MDD) as

disp^(ρ)_D(f′, f) = E_{x~D} [ Φ_ρ(ρ_{f′}(x, h_f(x))) ],

d^(ρ)_{f,F}(P, Q) = sup_{f′ ∈ F} ( disp^(ρ)_Q(f′, f) − disp^(ρ)_P(f′, f) ).

The generalization bound obtained in the paper estimates the error rate on the target domain by the empirical margin loss on the source domain and the empirical MDD between the source and target distributions:

err_Q(f) ≤ err^(ρ)_{P^}(f) + d^(ρ)_{f,F}(P^, Q^) + (complexity terms),

where the remaining terms on the right-hand side don't depend on f. Altogether, this bound leads to the objective

min_f  err^(ρ)_{P^}(f) + η d^(ρ)_{f,F}(P^, Q^).

Remember that the MDD is defined as a supremum over the hypothesis class. Thus, in addition to the actual classifier f, we introduce an auxiliary classifier f′ which should be interpreted as the maximizer from the definition of the MDD. By further introducing a feature extractor ψ shared by the maximizer and the minimizer, we can express the optimization problem as a minimax game:

min_{f, ψ}  err^(ρ)_{ψ(P^)}(f) + η ( disp^(ρ)_{ψ(Q^)}(f*, f) − disp^(ρ)_{ψ(P^)}(f*, f) ),

f* = argmax_{f′} ( disp^(ρ)_{ψ(Q^)}(f′, f) − disp^(ρ)_{ψ(P^)}(f′, f) ).

Modification

In practice, the margin loss is hard to optimize with gradient-based methods, so it is replaced with surrogate losses: the cross-entropy loss on the source domain and logarithmic terms in the discrepancy, where the auxiliary classifier f′ effectively acts as a domain discriminator. (We are effectively labeling samples from the source domain with 0 and samples from the target domain with 1.) The modified error rate and modified discrepancy then read:

E(P^) = E_{(x,y)~P^} [ L(f(ψ(x)), y) ],

D_γ(P^, Q^) = γ E_{x~P^} [ log σ_{h_f(ψ(x))}(f′(ψ(x))) ] + E_{x~Q^} [ log(1 − σ_{h_f(ψ(x))}(f′(ψ(x)))) ],

where L is the cross-entropy loss and σ denotes the softmax.

The new parameter γ = exp(ρ) is designed to account for the margin ρ. Ultimately, our minimax optimization problem reads:

min_{f, ψ}  E(P^) + η D_γ(P^, Q^),

max_{f′}  D_γ(P^, Q^).

The trade-off parameter η allows us to modulate the preference between the classification loss and the discrepancy loss.
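To make the modified discrepancy concrete, here is a minimal PyTorch sketch. The function and variable names are ours for illustration and not taken from the accompanying code, which may instead rely on the Transfer-learning-library implementation.

import torch
import torch.nn.functional as F

def modified_discrepancy(logits_src, logits_src_adv, logits_tgt, logits_tgt_adv, gamma=4.0):
    # Pseudo-labels h_f produced by the main classifier f (no gradients through them)
    pseudo_src = logits_src.argmax(dim=1).detach()
    pseudo_tgt = logits_tgt.argmax(dim=1).detach()
    # Source term: the auxiliary classifier f' should agree with f (-log sigma)
    loss_src = F.cross_entropy(logits_src_adv, pseudo_src)
    # Target term: f' should disagree with f (-log(1 - sigma))
    prob_tgt = F.softmax(logits_tgt_adv, dim=1).gather(1, pseudo_tgt.unsqueeze(1))
    loss_tgt = -torch.log(1.0 - prob_tgt + 1e-6).mean()
    # Negative of D_gamma: minimized by f', maximized by the feature extractor via the GRL
    return gamma * loss_src + loss_tgt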

Implementation

Architecture of the adversarial algorithm. Source: Zhang, Yuchen, et al. "Bridging theory and algorithm for domain adaptation." International Conference on Machine Learning. PMLR, 2019.

We use a resnet18 as the feature extractor ψ. The adversarial part of the objective is implemented with a gradient reversal layer (GRL), which acts as the identity in the forward pass and inverts the gradients during backpropagation, so that the feature extractor is updated in the direction opposite to the auxiliary classifier f′. The GRL also features a warm start: the inverted gradient updates coming from the discrepancy term are suppressed at first and are gradually scaled up according to a schedule depending on the epoch count. This allows the backbone model to learn the source dataset first, before the discrepancy loss is fully incorporated.
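A bare-bones gradient reversal layer with a warm-start schedule might look as follows. This is a sketch, not the exact implementation from the repository; the schedule is the standard one from Ganin and Lempitsky (2015), and the constants are assumptions.

import math
import torch

class GradientReverse(torch.autograd.Function):
    # Identity in the forward pass; multiplies the gradient by -coeff in the backward pass
    @staticmethod
    def forward(ctx, x, coeff):
        ctx.coeff = coeff
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.coeff * grad_output, None

def grl_coefficient(step, max_steps, lo=0.0, hi=1.0, alpha=10.0):
    # Warm start: the coefficient grows smoothly from lo to hi over the course of training
    p = step / max_steps
    return 2.0 * (hi - lo) / (1.0 + math.exp(-alpha * p)) - (hi - lo) + lo

# usage inside the training loop:
# reversed_features = GradientReverse.apply(features, grl_coefficient(step, max_steps))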

Training

Validation metrics

Since the labels of the target set are not available, we monitor two unsupervised validation metrics on the target domain (cf. Musgrave et al., 2022): prediction entropy and soft neighborhood density (SND). The mean prediction entropy is

Entropy = (1/N) Σ_{i=1}^N H(p_i),

where the entropy function is given by

H(p) = −Σ_j p_j log p_j,

and p_i is the softmax output of the i-th sample. Low entropy indicates confident predictions. Soft neighborhood density (SND) computes the entropy of the softmaxed target similarity matrix,

SND = H( softmax_τ(X^) ),

where F is the matrix of L^2-normalized target feature vectors, X = F^T F is the similarity matrix, X^ is X with the diagonal elements removed, and softmax_τ is the softmax function with temperature τ. The softmax temperature allows one to (de-)emphasize the most confident predictions. A high SND indicates that each feature is close to other features and entails good clustering. However, one should be careful, since a trivial model mapping all inputs into a single cluster will also yield a high SND score.
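Both metrics can be computed from the target logits and features in a few lines. The sketch below uses our own names (validation_metrics, tau) and assumes the logits and features of the whole target set fit into memory at once.

import torch
import torch.nn.functional as F

@torch.no_grad()
def validation_metrics(logits, features, tau=0.05):
    # Mean prediction entropy: low values indicate confident predictions
    p = F.softmax(logits, dim=1)
    entropy = -(p * torch.log(p + 1e-12)).sum(dim=1).mean()

    # SND: entropy of the softmaxed similarity matrix of L2-normalized features
    f = F.normalize(features, dim=1)
    sim = f @ f.t()                                              # X = F^T F
    n = sim.size(0)
    mask = ~torch.eye(n, dtype=torch.bool, device=sim.device)    # drop self-similarities
    off_diag = sim[mask].view(n, n - 1)
    q = F.softmax(off_diag / tau, dim=1)
    snd = -(q * torch.log(q + 1e-12)).sum(dim=1).mean()
    return entropy.item(), snd.item()

# usage: ent, snd = validation_metrics(target_logits, target_features, tau=0.005)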

Training setup

As a result of hyperparameter tuning, we settle on the following hyperparameter values:

  • trade-off: η=1
  • margin: γ=4
  • SND temperature: τ=0.005

To account for imbalanced classes, the classes are weighted by their inverse frequency in the source dataset. The model is trained with a batch size of 32 and stopped early after 50 epochs. The learning rate is set to 0.001 and decays by a factor of 0.1 every 20 epochs. The model is trained on a single Nvidia RTX6000 GPU.
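A sketch of the corresponding training setup is given below. The label counts, the stand-in model and the choice of SGD with momentum are illustrative assumptions; the inverse-frequency class weights, the learning rate and the decay schedule follow the values stated above.

from collections import Counter
import torch

# Toy stand-ins: in the real pipeline these come from the source datasets and the MDD model
source_labels = [0] * 5000 + [1] * 300 + [2] * 1200   # hypothetical label counts
model = torch.nn.Linear(512, 3)                       # stand-in for the resnet18-based classifier

counts = Counter(source_labels)
num_classes = len(counts)
# Inverse-frequency class weights to counteract the class imbalance
weights = torch.tensor([1.0 / counts[c] for c in range(num_classes)])
weights = weights * num_classes / weights.sum()       # rescale so the mean weight is 1
criterion = torch.nn.CrossEntropyLoss(weight=weights)

# Learning rate 0.001, decayed by a factor of 0.1 every 20 epochs
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)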

Results

Classification and transfer loss (left) and class accuracy (right) during training.
Entropy and SND during training.

The confusion matrix for the source dataset is presented below (remember that labels for the validation and test datasets are not available). We see that the model is largely able to discriminate the classes, with most misclassifications occurring between the two types of neutrophils. This is not unexpected, since the two types of neutrophils are very similar in appearance.

Confusion matrix for the source dataset.

Ultimately, we are able to achieve roughly the same macro f1 score on the test dataset as on the validation dataset, which we used during training for domain adaptation. The score is comparable with the top submissions of the challenge.

Micro and macro f1 scores for the validation set WBC1 (used for domain adaptation) and the test set WBC2.

The code for this article is available on GitHub.

References

  • Musgrave, Kevin, Serge Belongie, and Ser-Nam Lim. “Benchmarking Validation Methods for Unsupervised Domain Adaptation.” arXiv preprint arXiv:2208.07360 (2022).
  • Ganin, Yaroslav, and Victor Lempitsky. “Unsupervised domain adaptation by backpropagation.” International conference on machine learning. PMLR, 2015.
  • Jiang, Junguang, Bo Fu, and Mingsheng Long. “Transfer-learning-library.” (2020).
