PyTorch Metric Learning: What’s New

Kevin Musgrave
5 min read · Sep 16, 2020

PyTorch Metric Learning has seen a lot of changes in the past few months. Here are the highlights.

Distances, Reducers, and Regularizers

Loss functions are now highly customizable with the introduction of distances, reducers, and regularizers.

Distances

Consider the TripletMarginLoss in its default form.

For each triplet of anchor a, positive p, and negative n, this loss function attempts to minimize

max(0, d(a, p) - d(a, n) + margin)

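where “d” represents L2 distance. But what if we want to use a different distance metric, like unnormalized L1 or signal-to-noise ratio? With the distances module, you can try out these ideas easily by passing a distance object, such as LpDistance(normalize_embeddings=False, p=1) or SNRDistance(), to the loss function’s distance parameter.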
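To make the effect of swapping the distance concrete, here is a minimal, library-free sketch of the triplet margin loss for a single triplet, in plain Python so it runs without PyTorch. It is not the library's implementation: the helper names are hypothetical, the 0.05 margin default is an assumption, and `snr_distance` assumes the signal-to-noise distance is the variance of (other - anchor) divided by the variance of the anchor.

```python
import math
import statistics


def l2_distance(x, y):
    # Euclidean (L2) distance between two embeddings
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def l1_distance(x, y):
    # Unnormalized L1 (Manhattan) distance
    return sum(abs(a - b) for a, b in zip(x, y))


def snr_distance(anchor, other):
    # Assumed SNR distance: variance of the "noise" (other - anchor)
    # divided by the variance of the "signal" (anchor)
    noise = [o - a for a, o in zip(anchor, other)]
    return statistics.pvariance(noise) / statistics.pvariance(anchor)


def triplet_margin_loss(anchor, positive, negative, distance=l2_distance, margin=0.05):
    # max(0, d(a, p) - d(a, n) + margin): zero once the negative is at
    # least `margin` farther from the anchor than the positive is
    d_ap = distance(anchor, positive)
    d_an = distance(anchor, negative)
    return max(0.0, d_ap - d_an + margin)


# Swapping the metric is just a different `distance` argument
a, p, n = [0.0, 0.0], [3.0, 4.0], [1.0, 0.0]
print(triplet_margin_loss(a, p, n))                        # L2: 5 - 1 + 0.05
print(triplet_margin_loss(a, p, n, distance=l1_distance))  # L1: 7 - 1 + 0.05
```

Because the distance is just an argument, changing the metric never touches the loss logic, which is the same design idea behind the library's distance parameter.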

You can also use similarity measures, such as CosineSimilarity(), rather than distances, even though similarities are inversely related to distances.

With a similarity measure, the TripletMarginLoss internally swaps the anchor-positive and anchor-negative terms, and minimizes

max(0, s(a, n) - s(a, p) + margin)

where “s” represents similarity.
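The swap can be sketched the same way. This plain-Python example assumes cosine similarity as the similarity measure and a 0.05 margin; `cosine_similarity` and `triplet_margin_loss_sim` are hypothetical helpers, not library functions.

```python
import math


def cosine_similarity(x, y):
    # Dot product of the two vectors divided by the product of their norms
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return dot / (norm_x * norm_y)


def triplet_margin_loss_sim(anchor, positive, negative, similarity=cosine_similarity, margin=0.05):
    # Higher similarity means "closer", so the terms are swapped:
    # max(0, s(a, n) - s(a, p) + margin)
    s_ap = similarity(anchor, positive)
    s_an = similarity(anchor, negative)
    return max(0.0, s_an - s_ap + margin)


anchor = [1.0, 0.0]
# Hard triplet: the positive is orthogonal to the anchor, the negative is identical
print(triplet_margin_loss_sim(anchor, [0.0, 1.0], [1.0, 0.0]))
# Easy triplet: the positive is identical, the negative is orthogonal, so the loss is 0.0
print(triplet_margin_loss_sim(anchor, [1.0, 0.0], [0.0, 1.0]))
```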

Reducers

Losses are typically computed per element, pair, or triplet, and then are reduced to a single value by some operation, such as averaging. Many PyTorch loss functions accept a reduction parameter, which is usually “mean”, “sum”, or “none”.

In PyTorch Metric Learning, the reducer parameter serves a similar purpose, but instead takes in an object that performs the reduction. For example, a ThresholdReducer(low=10, high=30) can be passed into a loss function through its reducer parameter.

This ThresholdReducer will discard losses that fall outside of the range (10, 30), and then return the average of the remaining losses.
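A rough sketch of what a threshold-based reduction does, in plain Python. The function name is hypothetical, the bounds are treated as exclusive to match the open interval (10, 30), and returning 0.0 when every loss is discarded is an assumption of this sketch.

```python
def threshold_reduce(losses, low=None, high=None):
    # Keep only losses strictly inside (low, high); None disables that bound
    kept = [
        x for x in losses
        if (low is None or x > low) and (high is None or x < high)
    ]
    # Assumption: if every loss is discarded, return 0.0
    if not kept:
        return 0.0
    # Average whatever survived the filter
    return sum(kept) / len(kept)


per_triplet_losses = [5.0, 12.0, 20.0, 35.0]
print(threshold_reduce(per_triplet_losses, low=10, high=30))  # mean of 12.0 and 20.0 -> 16.0
```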

Regularizers

It’s common to add embedding or weight regularization terms to the core metric learning loss. For this reason, every loss function has an optional embedding_regularizer parameter, which accepts a regularizer object such as LpRegularizer().

And classification losses have an optional weight_regularizer parameter, which accepts a regularizer object such as RegularFaceRegularizer().
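To show how such regularization terms combine with the core loss, here is a hypothetical plain-Python sketch that adds the mean L2 norm of the embeddings (and, optionally, of a classifier's per-class weight rows) to a precomputed loss. The penalty choice and the `embedding_reg_weight` / `weight_reg_weight` scaling factors are assumptions made for illustration.

```python
def lp_norm(vector, p=2):
    # (sum of |v_i|^p) ^ (1/p)
    return sum(abs(v) ** p for v in vector) ** (1.0 / p)


def mean(values):
    return sum(values) / len(values)


def regularized_loss(core_loss, embeddings, class_weights=None,
                     embedding_reg_weight=1.0, weight_reg_weight=1.0):
    # Core metric learning loss plus a penalty on the mean embedding norm
    total = core_loss + embedding_reg_weight * mean([lp_norm(e) for e in embeddings])
    # Classification losses also own a weight matrix (one row per class),
    # which can be penalized the same way
    if class_weights is not None:
        total += weight_reg_weight * mean([lp_norm(w) for w in class_weights])
    return total


embeddings = [[3.0, 4.0], [0.0, 1.0]]  # norms 5 and 1, mean 3
print(regularized_loss(0.5, embeddings))
print(regularized_loss(0.5, embeddings, class_weights=[[6.0, 8.0]]))
```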

Flexible MoCo for Self-Supervised Learning

Momentum Contrast (MoCo) is a state-of-the-art self-supervised learning algorithm.

Figure 1 from the original paper

In a nutshell, it consists of the following steps:

  1. Initialize two convnets, Q and K, that have identical weights.
  2. At each iteration of training, set the weights of K to (m)*K + (1-m)*Q, where m is the momentum.
  3. Retrieve a batch of images, X, and a randomly augmented version, X′.
  4. Pass X into Q and X′ into K, and store K’s output in a large queue.
  5. Apply the InfoNCE loss (a.k.a. NTXent), using [Q_out, K_out] as positive pairs and [Q_out, queue] as negative pairs.
  6. Backpropagate and update Q.
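The numbered steps above can be sketched for a single query in plain Python, with flat lists of floats standing in for network weights and embeddings. This covers steps 2, 4, and 5 only; the function names are hypothetical, and the defaults m = 0.999 and temperature = 0.07 are borrowed from the MoCo paper as assumptions.

```python
import math
from collections import deque


def momentum_update(k_params, q_params, m=0.999):
    # Step 2: K <- m*K + (1-m)*Q, applied to every parameter
    return [m * k + (1 - m) * q for k, q in zip(k_params, q_params)]


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def info_nce(q_out, k_out, queue, temperature=0.07):
    # Step 5 for one query: the positive logit comes from (Q_out, K_out),
    # the negative logits from (Q_out, each key stored in the queue)
    logits = [dot(q_out, k_out) / temperature]
    logits += [dot(q_out, k) / temperature for k in queue]
    # Cross-entropy with the positive at index 0, using a stable log-sum-exp
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum_exp - logits[0]


# Step 4: a bounded FIFO queue of keys
queue = deque(maxlen=3)
queue.extend([[0.0, 1.0], [1.0, 0.0]])
loss = info_nce([1.0, 0.0], [1.0, 0.0], queue, temperature=1.0)
queue.append([0.6, 0.8])  # enqueue this iteration's key after computing the loss
```

A deque with maxlen gives the queue its first-in, first-out behavior for free: once it is full, each new key pushes out the oldest one.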

This simple procedure works amazingly well for creating good feature extractors. You might be wondering if it’s possible to use a different loss function, distance metric, or reduction method. And what about mining hard negatives from the queue?

With this library, it’s very easy to try these ideas by using CrossBatchMemory. First, initialize it with any tuple-based loss, and optionally supply a miner.

Next, create “labels” that indicate which elements form positive pairs, and specify which part of the batch to add to the queue.
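One way to build those labels and queue indices, sketched with plain lists for a batch made by concatenating [Q_out, K_out]. The function and the `previous_max_label` offset are hypothetical; the offset reflects an assumption that labels from earlier batches should never collide with new ones, so that keys already in the queue only ever serve as negatives.

```python
def make_labels(batch_size, previous_max_label=None):
    # Query i and key i get the same label, so (Q_out[i], K_out[i]) is a positive pair
    labels = list(range(batch_size)) * 2
    # Shift labels past anything already in the queue so queued keys act only as negatives
    if previous_max_label is not None:
        labels = [label + previous_max_label + 1 for label in labels]
    # Only the key half of the concatenated [Q_out, K_out] batch is enqueued
    enqueue_idx = list(range(batch_size, 2 * batch_size))
    return labels, enqueue_idx


labels, enqueue_idx = make_labels(3)
print(labels, enqueue_idx)
```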

Finally, compute the loss and step the optimizer. CrossBatchMemory takes care of all the mining, loss computation, and bookkeeping for the queue.
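The queue bookkeeping can be illustrated with a toy class. `ToyCrossBatchMemory` is hypothetical and deliberately simplified: it pairs each batch element with the rest of the batch and with everything in memory, returns positive and negative index pairs instead of computing a loss, and then enqueues the selected entries.

```python
from collections import deque


class ToyCrossBatchMemory:
    """Hypothetical stand-in for cross-batch memory bookkeeping (no loss, no miner)."""

    def __init__(self, memory_size):
        # Each entry is an (embedding, label) pair; the oldest entries are
        # evicted automatically once the memory is full
        self.memory = deque(maxlen=memory_size)

    def __call__(self, embeddings, labels, enqueue_idx):
        # References are the current batch followed by the stored entries
        ref_labels = list(labels) + [label for _, label in self.memory]
        positive_pairs, negative_pairs = [], []
        for i, anchor_label in enumerate(labels):
            for j, ref_label in enumerate(ref_labels):
                if i == j:
                    continue  # never pair an element with itself
                if anchor_label == ref_label:
                    positive_pairs.append((i, j))
                else:
                    negative_pairs.append((i, j))
        # A real implementation would hand these pairs (or a miner's subset)
        # to the wrapped loss here; the sketch just returns them
        for idx in enqueue_idx:
            self.memory.append((embeddings[idx], labels[idx]))
        return positive_pairs, negative_pairs


memory = ToyCrossBatchMemory(memory_size=4)
batch = [[0.0], [1.0], [0.1], [1.1]]  # [Q_out[0], Q_out[1], K_out[0], K_out[1]]
pos, neg = memory(batch, [0, 1, 0, 1], enqueue_idx=[2, 3])
print(len(pos), len(neg), len(memory.memory))
```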

To confirm that CrossBatchMemory works with MoCo, I wrote a notebook demonstrating that it achieves accuracy equivalent to the official implementation on CIFAR10 (using InfoNCE and no mining). You can run the notebook on Google Colab.

AccuracyCalculator

If you need to compute accuracy based on k-nearest-neighbors and k-means clustering, AccuracyCalculator is a convenient tool for that. By default, it computes 5 standard accuracy metrics when you pass in query and reference embeddings along with their labels.
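A plain-Python sketch of two k-nearest-neighbor metrics of the kind AccuracyCalculator reports, precision at 1 and R-precision. The real class works on embedding tensors; the function names here are hypothetical, the query and reference sets are assumed to be disjoint, and every query label is assumed to appear in the reference set (otherwise R would be zero).

```python
def squared_l2(x, y):
    return sum((a - b) ** 2 for a, b in zip(x, y))


def neighbor_labels(query, reference, reference_labels):
    # Labels of all reference embeddings, ordered from nearest to farthest
    order = sorted(range(len(reference)), key=lambda j: squared_l2(query, reference[j]))
    return [reference_labels[j] for j in order]


def precision_at_1(queries, query_labels, reference, reference_labels):
    # Fraction of queries whose single nearest neighbor has the same label
    hits = sum(
        neighbor_labels(q, reference, reference_labels)[0] == label
        for q, label in zip(queries, query_labels)
    )
    return hits / len(queries)


def r_precision(queries, query_labels, reference, reference_labels):
    # For each query, R = number of same-label references; score the top R neighbors
    scores = []
    for q, label in zip(queries, query_labels):
        r = reference_labels.count(label)
        top_r = neighbor_labels(q, reference, reference_labels)[:r]
        scores.append(sum(lab == label for lab in top_r) / r)
    return sum(scores) / len(scores)


reference = [[0.0], [0.1], [1.0], [1.1]]
reference_labels = [0, 0, 1, 1]
queries = [[0.05], [0.4]]  # the second query sits closer to class 0 than to class 1
query_labels = [0, 1]
print(precision_at_1(queries, query_labels, reference, reference_labels))
print(r_precision(queries, query_labels, reference, reference_labels))
```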

Adding your own accuracy metrics is straightforward: subclass AccuracyCalculator and add a method named calculate_some_amazing_metric.

Now when you call “get_accuracy”, the returned dictionary will include “some_amazing_metric”. Check out the documentation for details on how this works.
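The extension mechanism can be mimicked in a few lines. `MiniAccuracyCalculator` is a hypothetical stand-in, not the library class: it scans itself for methods whose names start with `calculate_` and reports each result under the rest of the name, and it takes precomputed nearest-neighbor labels rather than embeddings.

```python
class MiniAccuracyCalculator:
    """Hypothetical stand-in that mimics the calculate_<name> naming convention."""

    def get_accuracy(self, knn_labels, query_labels):
        # Every method named calculate_<name> contributes a "<name>" entry
        results = {}
        for attr in dir(self):
            if attr.startswith("calculate_"):
                metric_name = attr[len("calculate_"):]
                results[metric_name] = getattr(self, attr)(knn_labels, query_labels)
        return results

    def calculate_precision_at_1(self, knn_labels, query_labels):
        # Fraction of queries whose nearest neighbor has the query's label
        hits = sum(k[0] == q for k, q in zip(knn_labels, query_labels))
        return hits / len(query_labels)


class MyCalculator(MiniAccuracyCalculator):
    def calculate_some_amazing_metric(self, knn_labels, query_labels):
        # Hypothetical metric: fraction of queries whose label appears in the top 2 neighbors
        hits = sum(q in k[:2] for k, q in zip(knn_labels, query_labels))
        return hits / len(query_labels)


# knn_labels[i] lists the labels of query i's neighbors, nearest first
result = MyCalculator().get_accuracy([[0, 1], [0, 1]], [0, 1])
print(result)
```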

Distributed Wrappers

To make losses and miners work in multiple processes, wrap them with DistributedLossWrapper and DistributedMinerWrapper from pytorch_metric_learning.utils.distributed.

Why are these wrappers necessary? Under the hood, metric losses and miners usually need access to all tuples in a batch. But if your program runs in separate processes, the loss or miner in each process sees only its local share of the batch, and therefore only a fraction of all tuples. The distributed wrappers fix this problem. (Thanks to John Giorgi, who figured out how to implement this in his project on contrastive unsupervised textual representations, DeCLUTR.)
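A small simulation of the problem, in plain Python. Conceptually, a distributed wrapper gathers embeddings and labels from every process (for example with torch.distributed.all_gather) before the loss or miner runs; here, list slicing and concatenation stand in for the processes and the gather, and the helper names are hypothetical.

```python
from itertools import combinations


def positive_pairs(labels):
    # All unordered index pairs (i, j) whose labels match
    return [(i, j) for i, j in combinations(range(len(labels)), 2) if labels[i] == labels[j]]


def split_into_shards(batch, world_size):
    # Simulate each process receiving a contiguous slice of the global batch
    shard_size = len(batch) // world_size
    return [batch[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


global_labels = [0, 1, 0, 1]
shards = split_into_shards(global_labels, world_size=2)  # [[0, 1], [0, 1]]
# Without gathering, each process mines only inside its own shard
local_count = sum(len(positive_pairs(shard)) for shard in shards)
# After a gather, every process sees the full batch again
gathered = [label for shard in shards for label in shard]
global_count = len(positive_pairs(gathered))
print(local_count, global_count)
```

In this toy batch every positive pair straddles the two shards, so the per-process view finds none of them.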

Examples on Google Colab + Documentation

To see how this library works in actual training code, take a look at the example notebooks on Google Colab. There’s also a lot of documentation and an accompanying paper.

As a final note, here’s a long and narrow view of this library’s contents.

Hope you find it useful!
