Hi, I am Aleksei Shabanov, a deep learning engineer, PhD student and one of Catalyst's contributors. Today I would like to tell you about the metric learning pipeline that was added in the 20.08 release.
Note. Catalyst is a PyTorch framework for Deep Learning research and development. You get a training loop with metrics, model checkpointing, advanced logging and distributed training support without the boilerplate. Intro to Catalyst can be found here.
A few words about metric learning
Roughly speaking, the goal of metric learning is to build a feature extractor with the following behaviour: it should map objects of the same class to nearby points in the feature space and objects of different classes to distant points. These objects can be images, texts, sounds and so on. We will not discuss the theoretical side in detail; instead we will focus on the implementation. We designed our framework as a set of connected blocks with specified interfaces (Python abstract classes). A user can add new logic via inheritance.
First, we should point out that the training and validation stages of a metric learning pipeline are completely different (in comparison with more common scenarios like classification, where the only differences are the augmentations and the source of the DataLoader). So let’s take a look at these stages separately.
The idea of the training stage is as follows: we sample pairs, triplets or quadruplets from the dataset and then calculate a pair-based, triplet-based or quadruplet-based loss. We’ve implemented the triplet-based scenario, but it can be easily adapted to the other ones.
So let me introduce our design of the training stage which is based on two samplers: the first is to sample batches, the second is to sample triplets from these batches.
- The first one is implemented as BalanceBatchSampler, which puts P instances for K classes into the batch (P, K ≥ 2). This behaviour guarantees that we can always form triplets inside the batch and overcome class imbalance. Note that BalanceBatchSampler should be provided with information about the classes of all the items in the dataset. For this purpose the dataset should be a child of MetricLearningTrainDataset. (Note that BalanceBatchSampler can also be useful in a classic scenario, such as training a classifier.)
- The second one is represented by the abstract class InBatchTripletsSampler and its ready-to-use children, such as HardClusterSampler. After the triplets have been formed, they are used as the input of TripletMarginLoss from PyTorch. The in-batch sampling and the loss calculation are combined in a single criterion.
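To make the batch composition from the first bullet more concrete, here is a minimal pure-Python sketch of balanced batch sampling. It is not Catalyst's BalanceBatchSampler: the function balanced_batch and its arguments n_classes and n_per_class are hypothetical names chosen for illustration, and sampling rare classes with replacement is an assumption of this sketch.

```python
import random
from collections import Counter, defaultdict


def balanced_batch(labels, n_classes, n_per_class, rng):
    """Return dataset indices for one batch with a fixed number of
    instances for each of the randomly chosen classes.

    Hypothetical helper for illustration, not part of Catalyst.
    """
    if n_classes < 2 or n_per_class < 2:
        raise ValueError("need at least 2 classes and 2 instances per class")

    # Group dataset indices by class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    if n_classes > len(by_class):
        raise ValueError("n_classes exceeds the number of classes in the dataset")

    batch = []
    for label in rng.sample(sorted(by_class), n_classes):
        pool = by_class[label]
        if len(pool) >= n_per_class:
            batch.extend(rng.sample(pool, n_per_class))
        else:
            # Assumption of this sketch: rare classes are sampled with replacement.
            batch.extend(rng.choices(pool, k=n_per_class))
    return batch


# Heavily imbalanced toy labels: class 0 dominates the dataset.
labels = [0] * 50 + [1] * 5 + [2] * 3 + [3] * 2
batch = balanced_batch(labels, n_classes=3, n_per_class=2, rng=random.Random(0))
class_counts = Counter(labels[i] for i in batch)
print(class_counts)
```

Every class that lands in the batch contributes the same number of items regardless of how frequent it is in the dataset, so each anchor has at least one positive and several negatives to form triplets with.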
Note. Intuitively, it might seem easier to implement a Dataset that returns triplets. But this approach has several significant drawbacks. Firstly, it is extremely inefficient, because the most expensive part of the computation is feature extraction. Once we have extracted features from the objects, we would like to use as many triplets as possible, because selecting triplets is a lightweight procedure. But if the triplets are selected in advance, we lose a large number of potential triplets. Secondly, a selection mode such as ‘hard’ cannot be performed if the search is carried out over the entire dataset. At the same time, with InBatchSampler you can search for the hard samples in a reasonable time.
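The following sketch illustrates the point of the note: features are extracted once per batch, and choosing triplets on top of them is cheap. It is written in pure Python and is not Catalyst's implementation. The class names InBatchTripletSelector and HardestTripletSelector are hypothetical, and reading "hard" as "farthest positive and closest negative for each anchor" is an assumption. The loss follows the usual triplet margin formula max(d(a, p) - d(a, n) + margin, 0).

```python
import math
from abc import ABC, abstractmethod


class InBatchTripletSelector(ABC):
    """Hypothetical interface: pick triplets from already extracted features."""

    @abstractmethod
    def select(self, features, labels):
        """Return a list of (anchor, positive, negative) index triplets."""


class HardestTripletSelector(InBatchTripletSelector):
    """For every anchor take the farthest positive and the closest negative."""

    def select(self, features, labels):
        n = len(labels)
        triplets = []
        for a in range(n):
            positives = [i for i in range(n) if i != a and labels[i] == labels[a]]
            negatives = [i for i in range(n) if labels[i] != labels[a]]
            if not positives or not negatives:
                continue  # no valid triplet for this anchor
            pos = max(positives, key=lambda i: math.dist(features[a], features[i]))
            neg = min(negatives, key=lambda i: math.dist(features[a], features[i]))
            triplets.append((a, pos, neg))
        return triplets


def triplet_margin_loss(features, triplets, margin=1.0):
    """Mean of max(d(a, p) - d(a, n) + margin, 0) over all triplets."""
    losses = [
        max(math.dist(features[a], features[p])
            - math.dist(features[a], features[n]) + margin, 0.0)
        for a, p, n in triplets
    ]
    return sum(losses) / len(losses)


# Toy 2-D "features" for a batch with two classes, three items each.
features = [(0.0, 0.0), (1.0, 0.0), (3.0, 0.0), (0.0, 2.0), (0.0, 5.0), (4.0, 0.0)]
labels = [0, 0, 0, 1, 1, 1]

triplets = HardestTripletSelector().select(features, labels)
loss = triplet_margin_loss(features, triplets, margin=0.5)
print(triplets[0], loss)
```

A new selection strategy only needs another subclass with its own select method, which mirrors the inheritance-based extension mechanism mentioned earlier.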
Model evaluation usually follows the query/gallery protocol. It means that we split the validation dataset into query and gallery parts, calculate a distance matrix based on the extracted features and apply one of the retrieval metrics. The idea is that the metric is better when the gallery elements closest to a query have the same class as the query. If you work with academic retrieval datasets, they come with the exact query/gallery split. If you work with your custom dataset, the split should be done based on domain knowledge. Note that MNIST is not a retrieval dataset, which is why we simply pick 20% of it as the query part.
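As a rough illustration of the protocol, the sketch below splits a toy dataset into query and gallery parts at random and builds the query-to-gallery distance matrix. The helper names split_query_gallery and distance_matrix are hypothetical, and the random split is only a stand-in for the case where no domain-specific split exists.

```python
import math
import random


def split_query_gallery(n_items, query_fraction, rng):
    """Randomly assign a fraction of the item indices to the query part."""
    indices = list(range(n_items))
    rng.shuffle(indices)
    n_query = max(1, round(n_items * query_fraction))
    return sorted(indices[:n_query]), sorted(indices[n_query:])


def distance_matrix(query_features, gallery_features):
    """Euclidean distances: one row per query, one column per gallery item."""
    return [[math.dist(q, g) for g in gallery_features] for q in query_features]


# Toy features standing in for the output of a feature extractor.
features = [(float(i), 0.0) for i in range(10)]

query_idx, gallery_idx = split_query_gallery(len(features), 0.2, random.Random(0))
distances = distance_matrix(
    [features[i] for i in query_idx],
    [features[i] for i in gallery_idx],
)
print(len(distances), len(distances[0]))
```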
To provide information about the query/gallery split of your dataset, implement it as a child of QueryGalleryDataset; you can use MnistQGDataset as an example. As for the metric, you can implement it as a new callback or pick the ready-to-use CMCScoreCallback. The latter accumulates the features of all the samples from the validation loader, then builds a matrix of distances between queries and galleries and finally calculates the Cumulative Matching Characteristic (CMC).
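To show what the metric measures, here is a small pure-Python sketch of CMC@k under its common definition: the fraction of queries that have at least one gallery item of the same class among their k nearest gallery items. The function cmc_score is a hypothetical helper, not the code behind CMCScoreCallback, and the distance matrix is made up for the example.

```python
def cmc_score(distances, query_labels, gallery_labels, topk):
    """Fraction of queries with a same-class gallery item among the topk nearest."""
    hits = 0
    for row, query_label in zip(distances, query_labels):
        # Gallery indices ordered from the closest to the farthest.
        ranked = sorted(range(len(row)), key=row.__getitem__)
        if any(gallery_labels[j] == query_label for j in ranked[:topk]):
            hits += 1
    return hits / len(query_labels)


# Rows are queries, columns are gallery items (made-up distances).
distances = [
    [0.1, 0.9, 0.5],  # query "cat": the nearest gallery item is "cat"
    [0.2, 0.4, 0.3],  # query "dog": the gallery "dog" is the farthest item
]
query_labels = ["cat", "dog"]
gallery_labels = ["cat", "dog", "bird"]

cmc_at_1 = cmc_score(distances, query_labels, gallery_labels, topk=1)  # 0.5
cmc_at_3 = cmc_score(distances, query_labels, gallery_labels, topk=3)  # 1.0
print(cmc_at_1, cmc_at_3)
```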
The whole pipeline
Before putting things together via SupervisedRunner, let us focus on two facts:
- Since the designs of the training and validation stages are different, the loss callback cannot be called for val_loader and the metric callback cannot be applied to train_loader, so each callback has to be restricted to its own loader.
- We define a training epoch as a pass through all the classes of the training dataset (instead of all the samples, as in the classical scenario). This makes training epochs very quick. At the same time, a validation epoch includes all the samples in the query/gallery dataset. As a result, we should run the validation stage much less often than the training stage. PeriodicLoaderCallback is used for this.
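The scheduling idea from the second point can be sketched in a few lines: training runs every epoch, validation only every N-th one. This is not the code of PeriodicLoaderCallback; the function loaders_for_epoch, the loader names and the 1-based epoch numbering are assumptions made for illustration.

```python
def loaders_for_epoch(epoch, valid_period):
    """Return the loaders to run in a given (1-based) epoch."""
    loaders = ["train"]
    if epoch % valid_period == 0:
        # The expensive query/gallery validation runs only periodically.
        loaders.append("valid")
    return loaders


schedule = {epoch: loaders_for_epoch(epoch, valid_period=3) for epoch in range(1, 7)}
for epoch, loaders in schedule.items():
    print(epoch, loaders)
```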
To sum it up, the complete pipeline can be found in Catalyst’s Readme as a minimal example called "CV - MNIST with Metric Learning".
I hope this tutorial was useful to you. More details are available in the documentation and minimal examples. Do not hesitate to adapt this code to your task and to ask any questions in our Slack community. I would also like to thank Nikita, Julia, and Sergey for their help and advice during the metric learning feature release. See you in the next posts!