Triplet Loss and Siamese Neural Networks

Enosh Shrestha
4 min read · Oct 24, 2019


In the previous blog post, we implemented a model to learn a similarity function that takes two images as input and outputs 1 if they belong to the same class and 0 otherwise. Another way to train a Siamese Neural Network (SNN) is with the triplet loss function.

Triplet Loss

It is a distance-based loss function that operates on three inputs:

  1. anchor (a), an arbitrary data point,
  2. positive (p), an example with the same class as the anchor, and
  3. negative (n), an example with a different class from the anchor.

Mathematically, it is defined as: L=max(d(a,p)−d(a,n)+margin,0).
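As a minimal sketch of this formula (assuming the embeddings are already computed and using squared Euclidean distance; the function name and default margin are mine, not from the original code):

import tensorflow as tf

def triplet_loss(anchor, positive, negative, margin=1.0):
    # anchor, positive, negative: embedding tensors of shape [batch, dim]
    d_ap = tf.reduce_sum(tf.square(anchor - positive), axis=-1)
    d_an = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
    # max(d(a,p) - d(a,n) + margin, 0), averaged over the batch
    return tf.reduce_mean(tf.maximum(d_ap - d_an + margin, 0.0))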

We minimize this loss, which pushes d(a,p) toward 0 and pushes d(a,n) to be greater than d(a,p)+margin. This means that, after training, positive examples will be closer to the anchor while negative examples will be farther from it. The image below shows the effect of minimizing the loss.

Fig 1: Before (left) and after (right) minimizing triplet loss function

Triplet Mining

Based on the definition of the loss, there are three categories of triplets:

  • easy triplets: triplets which have a loss of 0, because d(a,p)+margin<d(a,n)
  • hard triplets: triplets where the negative is closer to the anchor than the positive, i.e. d(a,n)<d(a,p)
  • semi-hard triplets: triplets where the negative is not closer to the anchor than the positive, but which still have positive loss: d(a,p)<d(a,n)<d(a,p)+margin

Each of these definitions depends on where the negative lies relative to the anchor and the positive. We can therefore extend these three categories to the negatives themselves: hard negatives, semi-hard negatives, and easy negatives.

The figure below shows the three corresponding regions of the embedding space for the negative.

Fig 2: Regions of embedding space for negatives.
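To make the three categories concrete, here is a small sketch (plain Python; the function name is mine) that labels a triplet from its two distances and the margin:

def triplet_category(d_ap, d_an, margin):
    # d_ap = d(a, p), d_an = d(a, n)
    if d_an > d_ap + margin:
        return "easy"       # loss is already 0
    if d_an < d_ap:
        return "hard"       # negative is closer to the anchor than the positive
    return "semi-hard"      # d(a,p) <= d(a,n) <= d(a,p) + margin, loss still positive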

Triplet Mining for training

A model can be trained on triplets using either offline or online triplet mining.

Offline Triplet Mining: In this approach, we first generate the triplets manually and then fit the data to the network.
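A sketch of what that generation step might look like (random sampling from a small in-memory dataset; the names and the sampling strategy are mine, not necessarily the original code's):

import random

def make_triplets(data, labels, num_triplets):
    # Group example indices by class
    by_class = {}
    for i, y in enumerate(labels):
        by_class.setdefault(y, []).append(i)
    classes = list(by_class)
    triplets = []
    while len(triplets) < num_triplets:
        pos_class, neg_class = random.sample(classes, 2)
        if len(by_class[pos_class]) < 2:
            continue  # need two distinct examples for anchor and positive
        a, p = random.sample(by_class[pos_class], 2)
        n = random.choice(by_class[neg_class])
        triplets.append((data[a], data[p], data[n]))
    return triplets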

Online Triplet Mining: In this approach, we feed in a batch of training data, generate triplets from all the examples in the batch, and calculate the loss on them. This lets us randomize the triplets and increases the chance of finding triplets with high loss, which helps the model train faster. For a batch size of N, we can generate at most N³ triplets.
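A rough sketch of one popular online strategy, batch-hard mining, where each anchor is paired with its hardest positive and hardest negative within the batch (my own condensed version, assuming integer labels and squared Euclidean distances):

import tensorflow as tf

def batch_hard_loss(labels, embeddings, margin=1.0):
    # Pairwise squared Euclidean distances, shape [batch, batch]
    dot = tf.matmul(embeddings, embeddings, transpose_b=True)
    sq = tf.linalg.diag_part(dot)
    dists = tf.maximum(sq[:, None] - 2.0 * dot + sq[None, :], 0.0)
    same = tf.cast(tf.equal(labels[:, None], labels[None, :]), tf.float32)
    # Hardest positive: largest distance among same-class pairs
    d_ap = tf.reduce_max(dists * same, axis=1)
    # Hardest negative: smallest distance among different-class pairs
    big = tf.reduce_max(dists) + 1.0
    d_an = tf.reduce_min(dists + same * big, axis=1)
    return tf.reduce_mean(tf.maximum(d_ap - d_an + margin, 0.0))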

Implementation

Implementation 1: Offline triplets - The implementation using this method is straightforward. We create triplets from an existing dataset, as in the image below, and then minimize the loss on them. My implementation can be found here.

Fig 3: Triplet inputs: anchor (left), positive (middle), negative (right)

Implementation 2: Online triplets - For this implementation, I fed the classes of the inputs as a tensor along with the batch inputs, so the loss function needs to operate on one of the inner layers. To facilitate this, I created a higher-order function, i.e. a function that returns a function, as below:

def get_loss_function(labels):
    # Keras calls loss(y_true, y_pred); the batch labels are captured in the closure
    def loss(y_true, y_pred):
        return batch_hard_triplet_loss(tf.squeeze(labels), y_pred, margin=10, squared=False)
    return loss
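The closure is needed because Keras only passes y_true and y_pred to a loss function, so the batch labels have to reach the loss some other way. A hypothetical usage, assuming the batch_hard_triplet_loss helper referenced above is available:

labels = tf.constant([0, 0, 1, 2])       # classes of the current batch
loss_fn = get_loss_function(labels)
embeddings = tf.random.normal([4, 64])   # stand-in for the network's output
value = loss_fn(y_true=None, y_pred=embeddings)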

This implementation can be found here.

After training either of these models, we have an embedding space. So I fit a kNN model with k = 3 to the embeddings of the training data, and then use the kNN model to predict the classes of the test data.
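A minimal sketch of that step with scikit-learn (the embedding arrays here are random stand-ins for the trained network's outputs):

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

train_embeddings = np.random.rand(400, 64)    # stand-in: network(train_images)
train_labels = np.random.randint(0, 10, size=400)
test_embeddings = np.random.rand(1000, 64)    # stand-in: network(test_images)

knn = KNeighborsClassifier(n_neighbors=3)     # k = 3, as above
knn.fit(train_embeddings, train_labels)
predicted_classes = knn.predict(test_embeddings)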

Visualization

The images below show the embedding space of the training and test data. The model was trained on very little data (400 instances of the same class). The test data has ~1000 instances. As expected from the triplet loss function, we see neat clusters in the train data (left). These plots were generated using TensorBoard.

Embedding space for train data (left) and test data (right), visualized with PCA in TensorBoard
