Metric Learning with Catalyst

Aleksey Shabanov
Aug 29, 2020

Hi, I am Aleksei Shabanov, a deep learning engineer, PhD student and one of the Catalyst contributors. Today I would like to tell you about the metric learning pipeline that was added in the 20.08 release.

Note. Catalyst is a PyTorch framework for deep learning research and development. It gives you a training loop with metrics, model checkpointing, advanced logging and distributed training support without the boilerplate. An introduction to Catalyst can be found here.

A few words about metric learning

Figure 0. Feature space of the model trained via metric learning on Fashion-MNIST dataset. © zalandoresearch

Roughly speaking, the goal of metric learning is to build a feature extractor with the following behaviour: it should map objects of the same class to nearby points in the feature space, and objects of different classes to distant points. These objects can be images, texts, sounds and so on. We will not discuss the theory in detail; instead, we will focus on the implementation. We designed our framework as a set of connected blocks with specified interfaces (Python abstract classes), so a user can add new logic via inheritance.
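To illustrate the inheritance-based design, here is a minimal sketch of how such a block could look. The interface name TripletsSamplerBase and its sample method are hypothetical, chosen only for this illustration; they are not Catalyst's actual classes or signatures. A new strategy is added by subclassing the abstract interface and implementing a single method.

```python
from abc import ABC, abstractmethod
from itertools import permutations
from typing import List, Tuple

Triplet = Tuple[int, int, int]  # (anchor, positive, negative) indices within a batch


class TripletsSamplerBase(ABC):
    """Hypothetical interface: select triplets from the labels of one batch."""

    @abstractmethod
    def sample(self, labels: List[int]) -> List[Triplet]:
        """Return (anchor, positive, negative) index triplets."""


class AllTriplets(TripletsSamplerBase):
    """Hypothetical strategy: every valid triplet in the batch."""

    def sample(self, labels: List[int]) -> List[Triplet]:
        triplets = []
        # Ordered pairs of distinct indices: (a, p) and (p, a) both serve as anchors
        for a, p in permutations(range(len(labels)), 2):
            if labels[a] != labels[p]:
                continue  # the positive must share the anchor's class
            for n, label_n in enumerate(labels):
                if label_n != labels[a]:  # the negative must come from another class
                    triplets.append((a, p, n))
        return triplets


triplets = AllTriplets().sample([0, 0, 1, 1])
print(len(triplets))  # 4 anchor-positive pairs x 2 negatives each = 8
```

Because the abstract base cannot be instantiated, a forgotten sample method fails loudly at construction time rather than during training.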

First, note that the training and validation stages of a metric learning pipeline are completely different (unlike more common scenarios such as classification, where the two stages differ only in augmentations and in the source of the DataLoader). So let's look at these stages separately.

Training stage

Figure 1. Samples of the same class are shown in the same color.

The idea of the training stage is as follows: we sample pairs, triplets or quadruplets from the dataset and then compute a pair-based, triplet-based or quadruplet-based loss. We've implemented the triplet-based scenario, but it can easily be adapted to the other cases.
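For reference, the loss computed by PyTorch's TripletMarginLoss is the batch mean of max(d(a, p) - d(a, n) + margin, 0), where d is the Euclidean distance between an anchor a, a positive p and a negative n. The sketch below, assuming random embeddings and a margin of 0.5 chosen only for illustration, writes this formula out by hand and compares it with the built-in module.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random embeddings for a batch of 8 triplets (illustrative data only)
anchor = torch.randn(8, 16)
positive = torch.randn(8, 16)
negative = torch.randn(8, 16)
margin = 0.5

# Triplet loss written out: mean over the batch of max(d(a, p) - d(a, n) + margin, 0)
d_ap = F.pairwise_distance(anchor, positive)  # Euclidean distance per row
d_an = F.pairwise_distance(anchor, negative)
manual = torch.clamp(d_ap - d_an + margin, min=0).mean()

builtin = torch.nn.TripletMarginLoss(margin=margin, p=2)(anchor, positive, negative)
print(manual.item(), builtin.item())  # the two values should agree
```

The loss is zero for a triplet once the negative is farther from the anchor than the positive by at least the margin, which is exactly the "near for same class, far for different classes" behaviour described above.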

So let me introduce our design of the training stage, which is based on two samplers: the first samples batches, and the second samples triplets from these batches.

  1. The first one is implemented as BalanceBatchSampler, which puts P instances for each of K classes into a batch (P, K ≥ 2). This guarantees that triplets can always be formed inside the batch and counteracts class imbalance. Note that BalanceBatchSampler needs to know the classes of all the items in the dataset; for this purpose, the dataset should be a subclass of MetricLearningTrainDataset. (BalanceBatchSampler can be useful in a classic scenario as well, such as training a classifier.)
  2. The second one is represented by the abstract class InBatchTripletsSampler and its ready-to-use children: AllTripletsSampler, HardTripletsSampler and HardClusterSampler. Once the triplets have been formed, they are used as the input of PyTorch's TripletMarginLoss. In-batch sampling and loss computation are combined in TripletMarginLossWithSampler.
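A rough, framework-free sketch of the balanced batching idea is shown next. It is not BalanceBatchSampler's implementation: the function name balanced_batches, its parameters n_classes (classes per batch) and n_instances (samples per class), sampling with replacement for classes that are too small, and dropping an incomplete last group of classes are all assumptions of this sketch.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List


def balanced_batches(labels: List[int], n_classes: int, n_instances: int,
                     seed: int = 0) -> Iterator[List[int]]:
    """Yield batches of n_classes * n_instances dataset indices.

    One pass visits every class at most once, so the "epoch" is defined
    over classes rather than over samples.
    """
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    classes = list(by_class)
    rng.shuffle(classes)
    # Drop the incomplete tail so that every batch has exactly n_classes classes
    for start in range(0, len(classes) - n_classes + 1, n_classes):
        batch = []
        for cls in classes[start:start + n_classes]:
            pool = by_class[cls]
            if len(pool) >= n_instances:
                batch.extend(rng.sample(pool, n_instances))  # without replacement
            else:
                batch.extend(rng.choices(pool, k=n_instances))  # small class: reuse items
        yield batch


labels = [i % 10 for i in range(100)]  # 10 classes, 10 samples each
batches = list(balanced_batches(labels, n_classes=4, n_instances=3))
```

Each yielded list contains dataset indices, so a list of such batches can be passed to torch.utils.data.DataLoader through its batch_sampler argument; calling the generator again with a new seed gives a fresh set of batches for the next epoch.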

Note. Intuitively, it might seem easier to implement a Dataset that returns triplets. But this approach has several significant drawbacks. Firstly, it is extremely inefficient, because the most expensive part of the computation is feature extraction. Once features have been extracted from the objects in a batch, we would like to use as many triplets as possible, since selecting triplets is a lightweight procedure; if the triplets are selected in advance, we lose a large number of potential triplets. Secondly, a selection mode such as 'hard' mining cannot be performed in reasonable time if the search is carried out over the entire dataset. With an InBatchTripletsSampler, in contrast, hard samples can be found in reasonable time.
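The following sketch shows why in-batch mining is cheap: a single distance matrix over the already extracted features is enough to pick, for every anchor, the farthest positive and the closest negative. It assumes each batch holds at least two classes and at least two samples per class (which balanced batching guarantees). This "batch hard" rule and the function name hardest_triplets are illustrative; HardTripletsSampler's exact selection rule may differ.

```python
import torch


def hardest_triplets(features: torch.Tensor, labels: torch.Tensor):
    """For each anchor, pick the farthest positive and the closest negative.

    Assumes every anchor has at least one positive and one negative in the batch.
    """
    dist = torch.cdist(features, features)  # (B, B) Euclidean distances
    same = labels[:, None] == labels[None, :]
    eye = torch.eye(len(labels), device=labels.device).bool()

    pos_mask = same & ~eye  # positives exclude the anchor itself
    neg_mask = ~same

    # Hardest positive: largest distance among positives (others masked to -inf)
    positives = dist.masked_fill(~pos_mask, float("-inf")).argmax(dim=1)
    # Hardest negative: smallest distance among negatives (others masked to +inf)
    negatives = dist.masked_fill(~neg_mask, float("inf")).argmin(dim=1)

    anchors = torch.arange(len(labels), device=labels.device)
    return anchors, positives, negatives


torch.manual_seed(0)
labels = torch.tensor([0, 0, 1, 1, 2, 2])
features = torch.randn(6, 8, requires_grad=True)  # stand-in for extracted embeddings

# Mining needs no gradients; the loss is then computed on the original features
a, p, n = hardest_triplets(features.detach(), labels)
loss = torch.nn.TripletMarginLoss(margin=0.5)(features[a], features[p], features[n])
loss.backward()
```

Note that the features are computed once per batch and then indexed, so every sample can serve as an anchor, a positive and a negative without any extra forward passes.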

Validation stage

Figure 2. Validation stage. Colors show classes, circles and squares show query and gallery parts.

Models are usually evaluated with the query/gallery protocol: the validation dataset is split into query and gallery parts, a distance matrix is calculated from the extracted features, and one of the retrieval metrics is applied. The metric is higher when the gallery elements closest to each query have the same class as the query. Academic retrieval datasets usually come with an exact query/gallery split; for a custom dataset, the split should be made based on domain knowledge. Note that MNIST is not a retrieval dataset, which is why we simply pick 20% of it as queries.
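A hypothetical wrapper that attaches such a random 20% query split to any labelled dataset could look like the sketch below. The class name QueryGalleryWrapper and the (x, label, is_query) item format are assumptions made for illustration; they do not reproduce QueryGalleryDataset or MnistQGDataset.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class QueryGalleryWrapper(Dataset):
    """Hypothetical wrapper: marks a random fraction of a labelled dataset as queries.

    Items are returned as (x, label, is_query); the remaining items form the gallery.
    """

    def __init__(self, base: Dataset, query_fraction: float = 0.2, seed: int = 0):
        self.base = base
        n = len(base)
        gen = torch.Generator().manual_seed(seed)  # fixed seed: same split every run
        perm = torch.randperm(n, generator=gen)
        self.is_query = torch.zeros(n, dtype=torch.bool)
        self.is_query[perm[: int(query_fraction * n)]] = True

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        x, label = self.base[idx]
        return x, label, bool(self.is_query[idx])


base = TensorDataset(torch.randn(50, 4), torch.arange(50) % 5)  # toy labelled data
qg = QueryGalleryWrapper(base, query_fraction=0.2)
n_queries = sum(qg[i][2] for i in range(len(qg)))
```

Fixing the seed keeps the split identical across validation epochs, so metric values from different epochs remain comparable.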

To provide information about the query/gallery split, your dataset should be implemented as a subclass of QueryGalleryDataset; you can use MnistQGDataset as an example. As for the metric, it can be implemented as a new callback, or you can use the ready-made CMCScoreCallback. The latter accumulates the features of all the samples from the validation loader, then builds a matrix of distances between queries and galleries, and finally calculates the Cumulative Matching Characteristic (CMC).
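The CMC computation itself can be sketched in a few lines of PyTorch. This minimal version, assuming Euclidean distances and precomputed embeddings, counts the fraction of queries that have at least one same-class item among their k nearest gallery items. The function name cmc_at_k is hypothetical, and CMCScoreCallback's internals may differ. The toy data is made up so that the values can be checked by hand: two of the three queries hit at k = 1, and the third only hits at k = 3.

```python
import torch


def cmc_at_k(query_emb, gallery_emb, query_labels, gallery_labels, k=1):
    """Fraction of queries whose k nearest gallery items contain a same-class item."""
    dist = torch.cdist(query_emb, gallery_emb)  # (n_query, n_gallery)
    k = min(k, gallery_emb.shape[0])
    topk = dist.topk(k, dim=1, largest=False).indices  # k closest gallery items
    hits = (gallery_labels[topk] == query_labels[:, None]).any(dim=1)
    return hits.float().mean().item()


gallery_emb = torch.tensor([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
gallery_labels = torch.tensor([0, 1, 2])
query_emb = torch.tensor([[0.5, 0.0], [9.0, 0.5], [4.0, 6.0]])
query_labels = torch.tensor([0, 1, 1])

for k in (1, 2, 3):
    print(k, cmc_at_k(query_emb, gallery_emb, query_labels, gallery_labels, k=k))
```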

The whole pipeline

Before putting things together via SupervisedRunner, let us focus on two facts:

  • Since the training and validation stages are designed differently, the criterion must not be called on val_loader, and CMCScoreCallback must not be applied to train_loader. This behavior is provided by ControlFlowCallback, which restricts a wrapped callback to the specified loaders.
  • We define a training epoch as a single pass over all the classes of the training dataset (instead of all the samples, as in the classical scenario). This makes training epochs very short. A validation epoch, on the other hand, includes all the samples of the query/gallery dataset, so it makes sense to run validation much less often than training. PeriodicLoaderCallback is used for this.
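Both constraints amount to a simple per-epoch schedule. The sketch below is a framework-free illustration of that logic, not how ControlFlowCallback or PeriodicLoaderCallback are implemented; the names COMPONENT_LOADERS and plan_epoch are hypothetical, and validating every 100 epochs is an arbitrary value chosen for the example.

```python
from typing import Dict, List

# Hypothetical mapping: which loaders each pipeline component may run on.
# The triplet loss needs in-batch triplets (train only); CMC needs the
# query/gallery split (valid only).
COMPONENT_LOADERS: Dict[str, set] = {"triplet_loss": {"train"}, "cmc": {"valid"}}


def plan_epoch(epoch: int, valid_every: int) -> Dict[str, List[str]]:
    """Return {loader: components to run} for a 1-based epoch number."""
    loaders = ["train"]
    if epoch % valid_every == 0:  # the expensive validation loader runs only periodically
        loaders.append("valid")
    return {
        loader: sorted(name for name, allowed in COMPONENT_LOADERS.items() if loader in allowed)
        for loader in loaders
    }


plan = [plan_epoch(e, valid_every=100) for e in range(1, 601)]
n_valid_epochs = sum("valid" in p for p in plan)  # epochs 100, 200, ..., 600
```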

To sum up, the complete pipeline can be found in Catalyst's Readme as a minimal example called "CV - MNIST with Metric Learning".

If you run this example, the metric value reaches approximately 0.97 after 600 epochs.

I hope this tutorial was useful to you. More details are available in the documentation and minimal examples. Do not hesitate to adapt this code to your task and ask any questions in our Slack community. I would also especially like to thank Nikita, Julia, and Sergey for their help and advice during the metric learning feature release. See you in the next posts!
