Practical Metric Learning

Aleksey Shabanov
9 min read · Nov 22, 2022

This post accompanies the recent release of a new open-source project called OpenMetricLearning (OML), one of whose goals is to lower the entry threshold for metric learning pipelines. We will briefly introduce the theory, walk through code examples, and show how simple heuristics can perform on a level comparable with the current SotA. Since the project is new, each star on GitHub is essential for us.

About the Metric Learning problem

The goal of a metric learning (ML) pipeline is to build a function that takes 2 objects and estimates the distance (or similarity) between them. Having such a function, we can perform clustering, search, anomaly detection and so on. The usage of deep neural networks brings us to deep metric learning and its 2 main approaches:

  1. Siamese. A neural network of this type takes 2 objects as inputs and returns the probability that those objects are the same or similar (depending on the task setup).
  2. Representation learning. The neural network takes 1 object as input and returns a vector representing this object in some vector space. After that, we calculate classical distances among those vectors, for example, the cosine one.

Imagine we need to estimate all the pairwise distances among n objects. The first approach requires O(n^2) forward passes of the model, while the second one requires O(n) passes plus O(n^2) distance calculations. Since a distance calculation is much faster than a forward pass, the second approach is more prevalent in practice.
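To make the complexity argument concrete, here is a minimal sketch of the second approach (plain PyTorch, not tied to any particular library; model and objects are placeholders): n forward passes produce the embeddings, and the n x n distance matrix is then obtained with cheap matrix operations.

import torch
import torch.nn.functional as F

def pairwise_cosine_distances(model: torch.nn.Module, objects: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        embeddings = model(objects)              # O(n) forward passes
    embeddings = F.normalize(embeddings, dim=1)  # unit-length vectors
    # O(n^2) cheap dot products instead of O(n^2) forward passes
    return 1 - embeddings @ embeddings.T         # (n, n) matrix of cosine distances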

Here is a vector space of a model trained on Fashion MNIST:

Thanks to Zalando Research for this GIF.

Metric Learning vs Classification

Both tasks may look pretty similar, which makes the terminology confusing. On the one hand, you can train a classifier and use its features from the penultimate layer as vector representations for further distance estimation. On the other hand, you can train a model with a non-classification loss function (like triplet loss, more details below), but use the obtained features for classification by taking the most frequent label among the k nearest neighbours (kNN).
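As a toy illustration of the second direction, here is a hedged sketch of a kNN classifier built on top of embeddings (the tensors train_embeddings, train_labels and query_embeddings are placeholders, not part of any library):

import torch
import torch.nn.functional as F

def knn_classify(train_embeddings, train_labels, query_embeddings, k=5):
    # cosine similarity between the queries and the reference set
    sims = F.normalize(query_embeddings, dim=1) @ F.normalize(train_embeddings, dim=1).T
    topk = sims.topk(k, dim=1).indices           # indices of the k nearest neighbours
    neighbour_labels = train_labels[topk]        # shape: (n_queries, k)
    return neighbour_labels.mode(dim=1).values   # the most frequent label per query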

Despite the arguments above, we still can highlight the differences:

  • Classification requires the same classes to be present in the training and test sets, but ML does not need that.
  • ML can be used without class-level labels. For example, it’s able to deal with pairs of objects labelled by humans as positives and negatives. What is positive and what is negative depends on the task definition.

Are there benchmarks for Metric Learning?

Yes. As in classification, researchers compare their results on a set of popular datasets. Here is a leaderboard for image-based datasets.

How to train and validate a model?

Let’s consider the DeepFashion dataset from the leaderboard above. It consists of 17 categories of clothes (jackets, jeans, shorts and so on) and 8k classes (ids of particular items). The median class size is 5 images.

The classes are split into train and validation sets without intersection; note that this split is performed on the class level. The idea of validation is to simulate the search procedure, so the validation set is split into a query part (search requests) and a gallery part (search index). Note that here we split on the image level. For example, there is a jacket with id 001 and 7 images: 3 of them go to the query and the remaining 4 go to the gallery. Our goal is to train the model so that it returns these 4 gallery images as the most relevant results for the given 3 queries.
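To make the two levels of splitting explicit, here is a toy sketch (the dataframe column names are hypothetical and used only for illustration):

import pandas as pd

# class-level split: jacket_001 belongs to validation only, it never appears in train
df_val = pd.DataFrame({
    "path": [f"jacket_001_{i}.jpg" for i in range(7)],
    "label": ["jacket_001"] * 7,
})

# image-level split inside validation: 3 images act as queries, the other 4 as the gallery
df_val["is_query"] = [True, True, True, False, False, False, False]
df_val["is_gallery"] = ~df_val["is_query"]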

Let’s take a look at how we train a model with the classical triplet loss. Feel free to skip the section below if you are familiar with this loss.

A brief intro to triplet loss for those who are not familiar with it. The classical triplet loss looks as follows:

TripletLoss(a, p, n) = max(dist(a, p) − dist(a, n) + m, 0)

where:
(a, p, n) — a triplet consisting of an anchor, a positive (has the same id as the anchor) and a negative (has an id different from the anchor’s one). Depending on the task, we may require visual similarity or introduce another logic instead of requiring equality of ids;
dist(a, p) — the distance between the anchor and the positive object, which we want to minimise;
dist(a, n) — the distance between the anchor and the negative object, which we want to maximise;
m — the margin parameter.

There are other variations of triplet loss that sometimes converge better; for example, here is the “soft” version of triplet loss:

SoftTripletLoss(a, p, n) = log(1 + exp(dist(a, p) − dist(a, n)))

[Image: an anchor, a positive, and a negative example.]
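For reference, here is a minimal PyTorch sketch of both versions of the loss above (a plain illustration, not OML’s internal implementation; the cosine distance and the margin value are arbitrary choices here):

import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=0.1):
    d_ap = 1 - F.cosine_similarity(anchor, positive)  # distance to the positive, to be minimised
    d_an = 1 - F.cosine_similarity(anchor, negative)  # distance to the negative, to be maximised
    return torch.relu(d_ap - d_an + margin).mean()

def soft_triplet_loss(anchor, positive, negative):
    d_ap = 1 - F.cosine_similarity(anchor, positive)
    d_an = 1 - F.cosine_similarity(anchor, negative)
    return F.softplus(d_ap - d_an).mean()              # log(1 + exp(d_ap - d_an))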

Training pipeline

  1. Sampler creates a batch under the condition that it contains at least 2 classes with 2 images each (otherwise we are not able to obtain triplets). Usually, batches are class-balanced.
  2. The batch goes into a model and turns into a batch of vectors.
  3. Miner collects triplets from the batch of vectors. It can collect all the possible triplets or only the hardest ones (where the negative distance is minimal and the positive distance is maximal); a minimal sketch of such hard mining is shown right after this list. There are many other techniques available, for example, the utilisation of cross-batch memory to expand the effective batch size.
  4. Optimizer makes a step with respect to the gradient of the triplet loss calculated on the triplets above.
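Here is a minimal sketch of step 3 with hard mining (an illustration only; in OML this logic lives in miners such as the one used in the code example below):

import torch

def mine_hard_triplets(embeddings: torch.Tensor, labels: torch.Tensor):
    dists = torch.cdist(embeddings, embeddings)        # pairwise distances within the batch
    same_label = labels[:, None] == labels[None, :]
    triplets = []
    for a in range(len(labels)):
        pos_mask = same_label[a].clone()
        pos_mask[a] = False                             # exclude the anchor itself
        neg_mask = ~same_label[a]
        if pos_mask.any() and neg_mask.any():
            p = dists[a].masked_fill(~pos_mask, -1).argmax()            # the farthest positive
            n = dists[a].masked_fill(~neg_mask, float("inf")).argmin()  # the closest negative
            triplets.append((a, p.item(), n.item()))
    return triplets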

The scheme of the training pipeline may change, for instance:

  • You don't need a miner if you work with a classification loss (like Log loss or ArcFace).
  • If you have no class labels, but your data is labelled on the level of triplets or pairs, you pass it directly to a loss function, without using a miner.
  • Miner collects quadruplets to work with quadruplet loss and pairs to work with contrastive loss.

Validation pipeline

  1. Conduct inference on the whole validation set and accumulate obtained embeddings.
  2. Calculate the distances among all the possible query-gallery pairs, which gives a matrix of size Q x G. Then we sort it by rows to put the gallery items with the smallest distances to the top of the search results.
  3. Calculate metrics. The natural choice here is to pick metrics from the information retrieval world:
    CMC@k equals 1 if there is at least one correct answer among the first k results, 0 otherwise.
    Precision@k is a fraction of the correct answers among the first k results.
    MAP@k is similar to the previous one, but also takes into consideration the positions of the correct results.

Let’s take a look at an example where we have 3 queries (the blue ones), for which we return 5 images in order of increasing distance. Some of them have the same item id as the query (highlighted with green as correct answers), but some do not (highlighted with red as mistakes). For all of them, CMC@5 equals 1.
As for Precision@5, it’s a bit trickier, since we need to know the number of all possible correct answers for a given query in order not to decrease the metric when there is no possibility to return k correct results. For instance, let’s say the 1st query has 5 images with the same id in the gallery, the 2nd one has 3, and the 3rd one has 4. Then the denominator of Precision@5 is min(5, number of correct answers in the gallery): 5 for the 1st query, 3 for the 2nd, and 4 for the 3rd.
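A minimal sketch of these metrics (not OML’s implementation): is_correct is a boolean matrix where rows are queries and columns are the retrieved gallery items already sorted by distance, and n_relevant is the total number of correct answers per query.

import torch

def cmc_at_k(is_correct: torch.Tensor, k: int) -> torch.Tensor:
    # 1 if there is at least one correct answer among the first k results, per query
    return is_correct[:, :k].any(dim=1).float()

def precision_at_k(is_correct: torch.Tensor, n_relevant: torch.Tensor, k: int) -> torch.Tensor:
    # fraction of correct answers among the first k results, normalised by the
    # number of correct answers it is actually possible to return
    hits = is_correct[:, :k].sum(dim=1).float()
    return hits / torch.minimum(torch.full_like(n_relevant, k), n_relevant).float()

is_correct = torch.tensor([[1, 0, 1, 0, 0],
                           [0, 1, 0, 0, 0]]).bool()
n_relevant = torch.tensor([5, 3])
print(cmc_at_k(is_correct, k=5))                    # tensor([1., 1.])
print(precision_at_k(is_correct, n_relevant, k=5))  # tensor([0.4000, 0.3333])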

About OpenMetricLearning

OML is a new library for representation learning based on PyTorch. We listed simple code examples below for a better understanding of the library. In practice, you will most likely be more interested in the pipelines integrated with PyTorch Lightning or in the Config API (more details below), but under the hood, they follow the same logic.

Code for training:

import torch
from tqdm import tqdm

from oml.datasets.base import DatasetWithLabels
from oml.losses.triplet import TripletLossWithMiner
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models.vit.vit import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset

# download dummy dataset
dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)

# create a model based on a checkpoint pretrained in a self-supervised way
model = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)


train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
# create a criterion that consists of a loss function and a miner
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
# create a sampler that puts 2 instances of each of 2 classes into a batch
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)

# the rest of the code doesn't differ from a normal PyTorch pipeline
for batch in tqdm(train_loader):
    embeddings = model(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Code for validation:

import torch
from tqdm import tqdm

from oml.datasets.base import DatasetQueryGallery
from oml.metrics.embeddings import EmbeddingMetrics
from oml.models.vit.vit import ViTExtractor
from oml.utils.download_mock_dataset import download_mock_dataset

# download dummy dataset
dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root)

# create a model based on a checkpoint pretrained in a self-supervised way
model = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)

# create a metrics calculator that handles vector accumulation and metrics computation
calculator = EmbeddingMetrics()
calculator.setup(num_samples=len(val_dataset))

with torch.no_grad():
    for batch in tqdm(val_loader):
        batch["embeddings"] = model(batch["input_tensors"])
        calculator.update_data(batch)  # accumulating vectors

# compute metrics: cmc@k, precision@k, map@k
metrics = calculator.compute_metrics()

OpenMetricLearning vs PyTorchMetricLearning

Everything is relative, so, to give you a better understanding of OML, we will compare it with PyTorch Metric Learning (PML). PML is a popular library for metric learning that includes a rich collection of losses, miners, distances, and reducers; that is why we provide straightforward examples of using them with OML. Initially, we tried to use PML for our needs, but in the end we came up with our own library, which is more pipeline/recipe oriented. Here is how OML differs from PML:

  • OML has Config API which allows training models by preparing a config and your data in the required format (it’s like converting data into COCO format to train a detector from mmdetection).
  • OML focuses on end-to-end pipelines and practical use cases. It has config-based examples on popular benchmarks close to real life (like photos of products with thousands of ids). We found some good combinations of hyperparameters on these datasets, then trained and published the models and their configs. Thus, OML is more recipe-oriented than PML, and the author of PML confirms this, saying that his library is a set of tools rather than recipes; moreover, the examples in PML are mostly for the CIFAR and MNIST datasets.
  • OML has the Zoo of pretrained models that can be easily accessed from the code in the same way as in torchvision (when you type resnet50(pretrained=True)).
  • OML is integrated with PyTorch Lightning, so we can use the power of its Trainer. This is especially helpful when we work with DDP, so you can compare our DDP example with the PML one. By the way, PML also has Trainers, but they are not used in its examples, where custom train/test functions appear instead.

We believe that having the Config API, laconic examples, and the Zoo of pretrained models makes the entry threshold really low.
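For example, loading a model from the Zoo uses the same constructor as in the training snippet above; the checkpoint name "vits16_dino" is the one used throughout this post, and the other available names are listed in the OML documentation.

from oml.models.vit.vit import ViTExtractor

# the weights are downloaded automatically by the checkpoint name, similar to torchvision
extractor = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()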

How accurate can a model trained with OpenMetricLearning be?

It may be comparable with the current SotA methods. Let’s consider Hyp-ViT, which is a ViT architecture trained with contrastive loss, but with the embeddings projected into hyperbolic space. As the authors claim, such a space is able to describe the nested structure of real-world data, but the paper requires some heavy math to adapt the usual operations to the hyperbolic space.

We trained the same architecture with triplet loss, fixing the rest of the parameters: training and test transformations, image size, and optimizer. The trick was in heuristics in our miner and sampler:

  • Sampler forms the batches limiting the number of categories C in them. For instance, when C = 1, it puts only jackets into one batch and only jeans into another one (see the toy sketch below). This automatically makes the negative pairs harder: it’s more meaningful for a model to realise why two jackets are different than to understand the same about a jacket and a t-shirt.
  • Miner makes the task even harder by keeping only the hardest triplets (with maximal positive and minimal negative distances).

[Table: comparison of Hyp-ViT and our heuristic by CMC@1; the configs for our experiments are available here and here.]
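A toy sketch of the batching idea behind this heuristic (not OML’s actual sampler; the names and parameters are illustrative): each batch is restricted to n_categories categories, and labels are balanced inside them as usual.

import random
from collections import defaultdict

def category_aware_batches(labels, categories, n_categories=1, n_labels=8, n_instances=4):
    # group sample indices: category -> label -> list of image indices
    by_cat = defaultdict(lambda: defaultdict(list))
    for idx, (label, cat) in enumerate(zip(labels, categories)):
        by_cat[cat][label].append(idx)

    while True:
        batch = []
        for cat in random.sample(list(by_cat), n_categories):               # e.g. only "jackets"
            chosen = random.sample(list(by_cat[cat]), min(n_labels, len(by_cat[cat])))
            for label in chosen:
                batch += random.choices(by_cat[cat][label], k=n_instances)  # n_instances images per id
        yield batch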

Thus, we perform on the SotA level utilising simple heuristics and avoiding heavy math.

UPD: after a detailed ablation study, we realised that even without category-aware sampling our model is able to perform on the same level. In other words, having a simple hard mining mechanism is enough to be comparable with SotA models.

Summary

If you want to dive deeper into this type of machine learning problem, you are welcome to contribute to OpenMetricLearning. You can take one of the existing tasks (we have both scientific and engineering ones) or suggest your own idea by submitting a new issue.
We will be grateful for stars on GitHub & claps!
