How to use metric learning: embedding is all you need
Find out how metric learning differs from regular classification, and try it for yourself with my implementation of Supervised Contrastive Loss in PyTorch
One of the simplest and most common tasks in machine learning is classification. For instance, in computer vision you may want to fine-tune the last layers of a common convolutional neural network (CNN) to correctly classify samples into categories (classes). However, there are several fundamentally different ways to achieve that.
Metric learning is one of them, and today I would like to share with you how to correctly use it. In order to make things practical we’re going to look at Supervised Contrastive Learning (SupCon), which is a part of Contrastive Learning, which, in turn, is a part of Metric Learning, but more on that later.
The complete code can be found in the GitHub repo.
How classification is usually done
Before we dive into metric learning, it's a good idea to first understand how the task of classification is usually solved. One of the most important ideas in practical computer vision today is the convolutional neural network, which consists of two parts: an encoder and a head (in this case, a classifier).
First, you take an image and compute a set of features that captures the important qualities of that image. This is done using convolution and pooling operations (that's why they're called convolutional neural networks). After that, you flatten those features into a single vector and use a regular fully connected neural network to perform the classification. In practice, you take some model (for example, ResNet, DenseNet, EfficientNet, etc.) that is pretrained on a big dataset (like ImageNet), and fine-tune it on your task (either just the last layers or the whole model).
However, there are several things to note here. First of all, you usually only care about the outputs of the FC part of the network. That is, you take its outputs and supply them to the loss function in order to keep the model learning. In other words, you don't really care what happens in the middle of the network (for example, with the features from the encoder). Second of all, you (again, usually) train this whole thing with some basic loss function like Cross-Entropy.
To gain a better intuition about this two-step process (encoder + FC), you can think about it as follows: the encoder maps an image into some high-dimensional space (for example, 512 dimensions in the case of ResNet18, and 2048 for ResNet101). After that, the objective of the FC is to draw a boundary between the points that represent samples in order to map them to classes. These two things are trained at the same time, so you're trying to optimize the features and the "boundary in high-dimensional space" jointly.
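To make this concrete, here is a minimal sketch of the usual setup in PyTorch (the model choice, learning rate, and number of classes are illustrative, not tied to any particular task):

```python
import torch
import torch.nn as nn
from torchvision import models

# Pretrained encoder + FC head, trained jointly with cross-entropy.
model = models.resnet18(pretrained=True)           # encoder produces 512-d features
model.fc = nn.Linear(model.fc.in_features, 10)     # new head for, e.g., 10 classes

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # features and the boundary are optimized together

# Inside the training loop (images and labels come from your DataLoader):
#   logits = model(images)
#   loss = criterion(logits, labels)
#   loss.backward(); optimizer.step(); optimizer.zero_grad()
```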
What’s wrong with that approach? Well, nothing, really. It actually works just fine. But it doesn’t mean that there isn’t another way.
Metric Learning
One of the most interesting ideas (at least for me personally) in modern machine learning is called metric learning (or deep metric learning). In simple terms: what if, instead of going for the outputs of the FC layer, we take a closer look at the features that are generated by the encoder? What if we manage to optimize those features with some loss function, rather than the logits from the end of the network? This is actually what metric learning is about: generating good features (embeddings) with an encoder.
But what does "good" mean here? Well, if you think about it, in the case of computer vision you want similar features for images that are similar, and very distinct features for images that are nothing alike.
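As a toy illustration (with made-up vectors, not outputs of a real encoder), "good" embeddings behave like this: high cosine similarity for two views of the same object, and similarity near zero for unrelated images:

```python
import torch
import torch.nn.functional as F

# Hypothetical L2-normalized 512-d embeddings from an encoder.
cat_a = F.normalize(torch.randn(512), dim=0)
cat_b = F.normalize(cat_a + 0.01 * torch.randn(512), dim=0)  # a "similar" image: slight perturbation
car = F.normalize(torch.randn(512), dim=0)                   # an unrelated image

print(torch.dot(cat_a, cat_b).item())  # close to 1: similar images, similar features
print(torch.dot(cat_a, car).item())    # close to 0: nothing alike
```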
Supervised Contrastive Learning
Okay, suppose in Metric Learning all we care about is good features. But what's the deal with Supervised Contrastive Learning? To be honest, there is nothing that special about this specific approach. It's just a fairly recent paper that proposed some nice tricks, along with an interesting two-step approach:
- Train a good encoder that is capable of generating good features for images.
- Freeze the encoder, add an FC layer, and train just that.
You might be wondering how this differs from regular classifier training. The difference is that in regular training you train the encoder and the FC at the same time, whereas here you first train a decent encoder, then freeze it (stop training it), and train only the FC. The intuition behind this is that if we first manage to generate really good features for images, it should be easy to optimize the FC (whose objective, as we noted earlier, is to draw the boundary that separates the samples).
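Here is a minimal sketch of that second stage, assuming stage 1 has already produced a trained encoder (the names, sizes, and learning rate below are illustrative):

```python
import torch
import torch.nn as nn
from torchvision import models

# Stand-in for the SupCon-pretrained encoder: a ResNet18 with its FC head removed.
encoder = models.resnet18(pretrained=False)
encoder.fc = nn.Identity()                 # encoder now outputs raw 512-d features

# Stage 2: freeze the encoder and train only a linear classifier on top of it.
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

classifier = nn.Linear(512, 10)            # e.g. 10 classes for Cifar10
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

# Training loop sketch:
#   with torch.no_grad():
#       features = encoder(images)
#   loss = criterion(classifier(features), labels)
```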
Details of the training process
Let's dig into the details of the SupCon implementation.
Before checking the training loop, one thing you should understand about SupCon is which model is being trained. That's quite simple: an encoder (like ResNet, DenseNet, EfficientNet, etc.), but without the regular FC layers for classification.
Instead of a classification head, here we have a projection head. The projection head is a sequence of two FC layers that maps the encoder features to a lower-dimensional space (usually 128-dimensional). The reason for the projection head is that it's simply easier for the model to learn with 128 carefully selected features than with the several thousand features that come from the encoder.
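A minimal sketch of such a model: an encoder with its classification head removed, plus a two-layer projection head that maps the 512-d encoder features down to 128 dimensions (the exact layer widths and the choice of ResNet18 are just for illustration):

```python
import torch.nn as nn
from torchvision import models

class SupConModel(nn.Module):
    """Encoder + projection head used during SupCon training."""

    def __init__(self, feat_dim=128):
        super().__init__()
        self.encoder = models.resnet18(pretrained=True)
        in_dim = self.encoder.fc.in_features        # 512 for ResNet18
        self.encoder.fc = nn.Identity()             # drop the classification head
        self.head = nn.Sequential(                  # two FC layers -> 128-d embedding
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, feat_dim),
        )

    def forward(self, x):
        return self.head(self.encoder(x))
```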
All right, time to finally check the training loop.
- Construct a batch of N images. Unlike other Metric Learning approaches, you don’t need to care too much about the choice of those samples. Just take as many as you can, the rest will be handled by the loss.
- Forward those images through the network in pairs, where a pair is constructed as [augmentation(image_i), augmentation(image_i)], and get the embeddings. Normalize them.
- Take some image as an anchor. Find all the images of the same class in the batch and use them as positive samples. Find all the images of different classes and use them as negative samples.
- Apply the SupCon loss to the normalized embeddings, pulling positive samples closer to each other and, at the same time, pushing them away from the negative samples (a sketch of the loss follows after this list).
- After the training is done, delete the projection head and add an FC layer on top of the encoder (just like in regular classification training). Freeze the encoder and fine-tune the FC.
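As promised, here is a simplified sketch of a SupCon-style loss applied to normalized embeddings. It follows the formulation from the paper, but it is my own minimal version, not the code from my repo; the temperature value is just a typical choice:

```python
import torch

def supcon_loss(embeddings, labels, temperature=0.1):
    """Sketch of the supervised contrastive loss.

    embeddings: (N, D) tensor, assumed to be L2-normalized.
    labels:     (N,) tensor of class indices (the two augmented views of an
                image simply appear as two rows with the same label).
    """
    n = embeddings.size(0)
    device = embeddings.device

    # Pairwise cosine similarities scaled by the temperature.
    sim = embeddings @ embeddings.T / temperature
    # Numerical stability: subtract the per-row maximum before exponentiating.
    sim = sim - sim.max(dim=1, keepdim=True).values.detach()

    # Positives share a label with the anchor; the anchor itself is excluded.
    not_self = ~torch.eye(n, dtype=torch.bool, device=device)
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & not_self

    # Log-probability of each pair; the denominator runs over all non-anchor samples.
    exp_sim = torch.exp(sim) * not_self
    log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True))

    # Average over the positives of each anchor, then over anchors that have positives.
    pos_counts = pos_mask.sum(dim=1)
    valid = pos_counts > 0
    mean_log_prob_pos = (log_prob * pos_mask).sum(dim=1)[valid] / pos_counts[valid]
    return -mean_log_prob_pos.mean()
```

The label-based positive mask is exactly what makes this the supervised variant of contrastive learning: in the self-supervised case, the only positive for an anchor is the other augmented view of the same image.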
Several things to keep in mind here:
- First, after the training is done, it's better to get rid of the projection head and use the features before it. The authors attribute this to information being lost in the head, since it reduces the embedding size.
- Second, the choice of augmentations is important. The authors propose a combination of cropping and color jittering.
- Third, SupCon deals with all the images in the batch at once (so there is no need to construct pairs or triplets), and the more images there are in the batch, the easier it is for the model to learn (since SupCon performs implicit hard positive and negative mining).
- Fourth, you can actually stop at step 4, meaning that it's possible to do classification with embeddings alone, without any FC layer. To do that, compute embeddings for all the train samples. Then, at validation time, compute an embedding for each sample, compare it to every train embedding (compare = cosine distance), take the most similar one, and use its class (see the sketch below).
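And a sketch of that fourth point, classification with embeddings alone (the tensors are assumed to be precomputed with the frozen encoder; the function name is mine):

```python
import torch.nn.functional as F

def nearest_neighbor_predict(val_embeddings, train_embeddings, train_labels):
    """Assign each validation sample the class of its most similar train embedding."""
    val_embeddings = F.normalize(val_embeddings, dim=1)
    train_embeddings = F.normalize(train_embeddings, dim=1)
    similarity = val_embeddings @ train_embeddings.T   # cosine similarity of normalized vectors
    nearest = similarity.argmax(dim=1)                 # index of the closest train sample
    return train_labels[nearest]
```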
PyTorch implementation
There is actually a semi-official implementation of SupCon in PyTorch. Unfortunately, it contains some very irritating hidden bugs. One of the most serious ones: the creator of the repo used his own implementations of ResNets, and due to bugs in them, the usable batch size is half of what it can be with regular torchvision models. On top of that, the repo has no validation or visualizations, so you have no idea when to stop training. In my repo I fix all of these issues and add more tricks for stable training.
To be more precise, in my implementation you have access to:
- Augmentations with albumentations
- Yaml configs
- t-SNE visualizations
- 2-step validation (for features before and after the projection head) using metrics like AMI, NMI, mAP, precision_at_1, etc., via PyTorch Metric Learning.
- Exponential Moving Average for more stable training, and Stochastic Weight Averaging for better generalization and overall performance.
- Automatic Mixed Precision training in order to be able to train with a bigger batch size (roughly by a factor of 2).
- LabelSmoothing loss, and LRFinder for the second stage of the training (FC).
- Support of timm models and jettify optimizers
- Fixing the seeds in order to make the training deterministic.
- Saving weights based on validation, with logs written to regular .txt files as well as TensorBoard logs for later examination.
Out of the box, the repo supports Cifar10 and Cifar100. However, it's quite trivial to add your own datasets. In order to run the whole pipeline, do the following:
```
python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage1.yml
python swa.py --config_name configs/train/swa_supcon_resnet18_cifar100_stage1.yml
python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage2.yml
python swa.py --config_name configs/train/swa_supcon_resnet18_cifar100_stage2.yml
```
After that you can check the t-SNE visualizations for Cifar10 and Cifar100 in the corresponding Jupyter notebook.
Final thoughts
Metric Learning is a very powerful thing. However, it's been quite hard for me to reach the level of accuracy that regular CE/LabelSmoothing training can provide. Moreover, it can also be computationally expensive and unstable during training. I've tested SupCon and other metric losses on a variety of tasks (classification, out-of-distribution prediction, generalization to new classes, etc.), and the advantage of using something like SupCon is unclear.
Wait a minute, what's the point then? I personally see two things here. First, SupCon (and other Metric Learning approaches) can still provide more structured clusters than CE, since it directly optimizes that property. Second, it's still very beneficial to have one more skill/tool that you can try. So it's possible that with a better set of augmentations, or a different dataset (maybe one with more fine-grained classes), SupCon can yield better results, not just results on par with regular classification training.
So we just have to try and experiment. No free lunch, right?