One Shot learning, Siamese networks and Triplet Loss with Keras
In the modern Machine Learning era, Deep Convolutional Neural Networks are a very powerful tool for working with images, for all kinds of tasks. We’ve seen networks that can classify or detect about 1000 different kinds of objects with very good performance. The traditional way of building such a classifier is to train the network to directly predict the class of its input.
In a 2015 paper called FaceNet, the authors took another approach. Instead of using a neural network as a direct classifier, they built a “Siamese network” and compared its outputs to decide whether two inputs were similar. This approach is often described as “One Shot Learning”, and it relies on a specific loss function called the “Triplet Loss”.
This article explores these two concepts and applies them to the MNIST dataset using Keras.
The One Shot Learning concept and Siamese Networks
In a traditional classification project, you typically train a neural network so that, given a picture as input, it outputs a probability (usually through a softmax) for each class. So if you want to know whether the picture contains a cat, a dog or a horse, your network will output 3 probabilities, one per class. If you input a dog picture, the network is supposed to output a high probability for the dog output and low probabilities for the cat and horse outputs. Another way of seeing this is that the model is answering the question “Which class is the input?”
For this kind of training, you feed the network vast quantities of images of cats, dogs and horses until it learns to classify them properly. After thorough testing, you are happy with your model and you can deploy it somewhere to be of some use.
Now what if you need to add a monkey class to this animal classifier? You need to find lots of monkey pictures, add a new output to your network, retrain it, retest it and redeploy it.
This works perfectly in some situations, but in many cases it is impractical. What if the classes change too quickly to retrain? A typical example is employee face recognition for company security gates. The classes here are the employees, and you need to train your system with employee faces. Retraining, retesting and redeploying the model each time an employee joins or leaves the company is not realistic, and neither is collecting thousands of pictures of each employee.
Enter the 2015 paper “FaceNet: A Unified Embedding for Face Recognition and Clustering” (following this closely related one). The main idea is to make a decision based on only one sample: compare two images and tell whether they are the same or not. Hence “One Shot Learning”.
The question answered by our system becomes: “are these two pictures similar?” We are not building a direct classifier; we are building a similarity comparer.
To do this, the Siamese network architecture is used: each input picture goes through one of two networks, but these two networks are actually the same. Same architecture, same weights: it is really a single network used for two different inputs. The outputs are then compared to decide whether the inputs are similar or not.
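To make the weight sharing concrete, here is a minimal sketch in Keras (assuming TensorFlow 2.x and `tensorflow.keras`). The tiny `embedding_net` and its layer sizes are hypothetical placeholders, not the network used later in this article; the point is that calling the same model object on two inputs reuses exactly the same weights.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical tiny embedding network: 28x28x1 image -> 10-number embedding.
embedding_net = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(32, activation="relu"),
        layers.Dense(10),
    ],
    name="embedding_net",
)

# Two inputs, one network: calling the same model object twice shares its weights.
input_a = keras.Input(shape=(28, 28, 1), name="input_a")
input_b = keras.Input(shape=(28, 28, 1), name="input_b")
emb_a = embedding_net(input_a)
emb_b = embedding_net(input_b)

# Squared Euclidean distance between the two embeddings, built from stock layers.
diff = layers.Subtract()([emb_a, emb_b])
sq_distance = layers.Dot(axes=1)([diff, diff])

siamese = keras.Model([input_a, input_b], sq_distance, name="siamese")

# Identical pictures go through identical weights, so their distance is 0.
x = np.random.rand(4, 28, 28, 1).astype("float32")
d = siamese.predict([x, x], verbose=0)  # shape (4, 1)
```

Because `embedding_net` appears only once inside `siamese`, the Siamese model has exactly as many trainable weights as the embedding network itself.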
To analyse the similarity between two pictures, we need to transform each input picture into a smaller representation, say a single vector. This representation is usually called an embedding. We need to build these embeddings so that they have the following properties:
- Two similar images produce two embeddings whose mathematical distance is small
- Two very different images produce two embeddings whose mathematical distance is large
- The embedding is L2-normalized, i.e. each embedding is forced to lie on the unit hypersphere
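A small NumPy illustration of the last property, with random vectors standing in for real embeddings: once each embedding is divided by its L2 norm, it lies on the unit hypersphere, so the Euclidean distance between two embeddings is always between 0 (same point) and 2 (opposite points). This bounded range is one reason why a fixed margin and a fixed decision threshold make sense.

```python
import numpy as np

def l2_normalize(embeddings, eps=1e-10):
    """Project each row (one embedding) onto the unit hypersphere."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.maximum(norms, eps)

def euclidean_distance(a, b):
    """Row-wise Euclidean distance between two batches of embeddings."""
    return np.linalg.norm(a - b, axis=1)

rng = np.random.default_rng(0)
raw = rng.normal(size=(5, 10))  # 5 hypothetical raw 10-number embeddings
emb = l2_normalize(raw)

print(np.linalg.norm(emb, axis=1))    # every norm is 1 (up to rounding)
print(euclidean_distance(emb, -emb))  # opposite points: the maximum distance, 2
```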
To do that, we need to train a neural network specialized for our context that produces embeddings with these properties. For face comparison, the ‘two similar images’ referenced above could be the same face in two different photos (producing embeddings at a short distance), and ‘two very different images’ could be two different faces (producing embeddings at a large distance). For animals, “two similar images” could be the same species in different photos (short distance) versus photos of two different species (large distance); you get the point.
How long should your embedding be? Four numbers? 40? More? The idea is that your embedding must contain a “good enough” representation of your class to differentiate it from the others. The embedding length is a new hyper-parameter of the problem.
For example, to differentiate these four super-simple 3x3 pixel classes, two numbers are enough.
In their FaceNet paper, the authors chose 128 numbers to represent human face characteristics. They also tried other sizes (64, 256 and 512): the smaller embedding performed slightly worse and the larger ones did not improve accuracy, so they considered 128 a good compromise.
Although it does not work with images, Google’s BERT (in its base version) uses 768 numbers to represent the semantic meaning of a word or phrase.
In this article, we chose to have embeddings of size 10.
Project Architecture for triplet loss
There are two main ways of learning the parameters for our network. Firstly, we could simply consider the right part of the system as a binary classifier with Y=1 if the input pictures are from the same class and Y=0 if not.
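A minimal sketch of this first option, assuming `tensorflow.keras`. The shared embedding network is a hypothetical placeholder, and feeding the element-wise squared difference of the two embeddings into a single sigmoid unit is just one common choice of classifier head, not something prescribed by this article.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical shared embedding network (placeholder sizes).
embedding_net = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(32, activation="relu"),
        layers.Dense(10),
    ]
)

input_a = keras.Input(shape=(28, 28, 1))
input_b = keras.Input(shape=(28, 28, 1))
emb_a = embedding_net(input_a)
emb_b = embedding_net(input_b)

# Element-wise squared difference, then one logistic unit:
# an output near 1 means "same class" (Y=1), near 0 means "different" (Y=0).
diff = layers.Subtract()([emb_a, emb_b])
sq_diff = layers.Multiply()([diff, diff])
same_probability = layers.Dense(1, activation="sigmoid")(sq_diff)

pair_classifier = keras.Model([input_a, input_b], same_probability)
pair_classifier.compile(optimizer="adam", loss="binary_crossentropy")

x_a = np.random.rand(3, 28, 28, 1).astype("float32")
x_b = np.random.rand(3, 28, 28, 1).astype("float32")
probabilities = pair_classifier.predict([x_a, x_b], verbose=0)  # shape (3, 1)
```

Training would then be an ordinary `fit([x_a, x_b], y)` call, with `y` set to 1 for same-class pairs and 0 otherwise.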
Secondly, we could use the approach from the FaceNet paper. Since we want small distances between the embeddings of similar images and large distances between the embeddings of different images, the idea is to consider:
- A starting picture, called the Anchor
- A picture from the same class as the anchor, called the Positive
- A picture from a different class to the anchor, called the Negative
For each triplet (let’s call the three embeddings A, P and N respectively), we want:
distance(A,P) < distance(A,N)
Let’s rewrite it differently
distance(A,P) - distance(A,N) < 0
To prevent our network from learning a trivial solution, such as outputting the same embedding for every input (both distances are then 0), let’s force a margin between distance(A,P) and distance(A,N) with a margin parameter:
distance(A,P) - distance(A,N) + margin < 0
Our loss function will be based on this quantity and will then have the following form:
L(A,P,N) = max(distance(A,P) - distance(A,N) + margin, 0)
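The loss fits in a few lines of NumPy. This sketch assumes the squared Euclidean distance (as in the FaceNet paper) and an illustrative margin of 0.2. The worked example shows an easy triplet (zero loss), a triplet still inside the margin (positive loss), and the degenerate case where every embedding is identical, which now costs exactly `margin` instead of being a free solution.

```python
import numpy as np

def triplet_loss(a, p, n, margin=0.2):
    """Mean of max(d(A,P) - d(A,N) + margin, 0) over a batch of triplets.

    a, p, n: arrays of shape (batch, embedding_size) holding the anchor,
    positive and negative embeddings; d is the squared Euclidean distance.
    """
    d_ap = np.sum((a - p) ** 2, axis=1)
    d_an = np.sum((a - n) ** 2, axis=1)
    return float(np.mean(np.maximum(d_ap - d_an + margin, 0.0)))

anchor = np.array([[0.0, 0.0]])
positive = np.array([[0.1, 0.0]])
print(triplet_loss(anchor, positive, np.array([[1.0, 0.0]])))  # 0.0: easy triplet
print(triplet_loss(anchor, positive, np.array([[0.2, 0.0]])))  # ≈ 0.17: inside the margin
zeros = np.zeros((1, 2))
print(triplet_loss(zeros, zeros, zeros))  # 0.2: identical embeddings cost `margin`
```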
Preparing data for training
In this article we will use the MNIST dataset. All pictures will be sorted according to their class, which will be useful for preparing our triplet batches later.
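A minimal NumPy sketch of this sorting step. The function name `group_by_class` is hypothetical; it returns a list whose `c`-th entry holds every image of class `c`.

```python
import numpy as np

def group_by_class(x, y, num_classes=10):
    """Split the images x into one array per class, using the integer labels y."""
    return [x[y == c] for c in range(num_classes)]

# Tiny synthetic check: 6 "images" with labels 0, 1, 0, 2, 1, 0.
x_demo = np.arange(6).reshape(6, 1)
y_demo = np.array([0, 1, 0, 2, 1, 0])
groups = group_by_class(x_demo, y_demo, num_classes=3)
print([len(g) for g in groups])  # [3, 2, 1]
```

With the real data, `(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()` returns 28x28 `uint8` images and integer labels that can be passed straight to this function (adding a channel axis and scaling pixels to [0, 1] is usually done as well).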
Now we need to prepare the triplet batches for training. Every sample of our batch will contain 3 pictures: the Anchor, a Positive and a Negative. How should we choose our triplets? A first approach would be to choose them completely at random.
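A sketch of this purely random strategy, assuming the images have already been sorted by class into a list `groups`, where `groups[c]` holds the images of class `c` (the function and argument names are hypothetical). Each triplet picks a random anchor class, two distinct images from that class, and one image from a different, randomly chosen class.

```python
import numpy as np

def get_random_triplets(groups, batch_size, rng):
    """Draw batch_size (anchor, positive, negative) triplets uniformly at random.

    groups[c] is an array of images of class c; each class needs >= 2 images.
    """
    num_classes = len(groups)
    image_shape = groups[0].shape[1:]
    anchors, positives, negatives = (
        np.zeros((batch_size, *image_shape), dtype=groups[0].dtype) for _ in range(3)
    )
    for i in range(batch_size):
        anchor_class = rng.integers(num_classes)
        # Shift by 1..num_classes-1 so the negative class always differs.
        negative_class = (anchor_class + rng.integers(1, num_classes)) % num_classes
        a_idx, p_idx = rng.choice(len(groups[anchor_class]), size=2, replace=False)
        n_idx = rng.integers(len(groups[negative_class]))
        anchors[i] = groups[anchor_class][a_idx]
        positives[i] = groups[anchor_class][p_idx]
        negatives[i] = groups[negative_class][n_idx]
    return anchors, positives, negatives

# Demo with 3 synthetic classes whose 2x2 "images" are filled with their label.
rng = np.random.default_rng(0)
demo_groups = [np.full((5, 2, 2), c, dtype=np.float32) for c in range(3)]
a, p, n = get_random_triplets(demo_groups, batch_size=8, rng=rng)
```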
We need our similarity network to learn to clearly differentiate our classes. Based on the definition of the loss, there are three categories of triplets:
- easy triplets: triplets which have a loss of 0, because d(A,P)+margin<d(A,N)
- hard triplets: triplets where the negative is closer to the anchor than the positive, i.e. d(A,N)<d(A,P)
- semi-hard triplets: triplets where the negative is not closer to the anchor than the positive, but which still have positive loss: d(A,P) < d(A,N) < d(A,P) + margin
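These three categories translate directly into a small helper. This plain-Python sketch takes the two distances of one triplet and an illustrative margin of 0.2; the exact boundary cases (ties) are settled arbitrarily.

```python
def triplet_category(d_ap, d_an, margin=0.2):
    """Classify a triplet from d(A,P) and d(A,N)."""
    if d_ap + margin < d_an:
        return "easy"       # loss is 0: nothing to learn
    if d_an < d_ap:
        return "hard"       # the negative is closer than the positive
    return "semi-hard"      # d(A,P) <= d(A,N) <= d(A,P) + margin

for d_ap, d_an in [(0.1, 0.5), (0.5, 0.1), (0.3, 0.4)]:
    print(d_ap, d_an, triplet_category(d_ap, d_an))
```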
Each of these categories depends on where the negative lies relative to the anchor and the positive. Since our loss is 0 for easy triplets, they produce no gradient and our model learns nothing from them. On the other hand, hard triplets generate a high loss and have a big impact on our network parameters, which gives any mislabelled data too much weight. So we must choose a strategy that mixes easy, hard and maybe semi-hard triplets. In their original paper, the FaceNet authors chose to draw a first, relatively large random sample of triplets (making sure the class distribution was roughly respected) and picked N/2 hard and N/2 random samples for their batch of size N. For this article we will make batches of 32 triplets, made of 16 hard and 16 random ones taken from a big batch of 200 random triplets.
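A NumPy sketch of this mixing strategy, under stated assumptions: `embed` stands for any function that maps a batch of images to embeddings (for instance the embedding network’s `predict`), the per-triplet loss is used as the measure of “hardness”, and all names are hypothetical. From a pool of random triplets, it keeps the `n_hard` triplets with the highest loss and adds `n_random` others drawn at random from the rest.

```python
import numpy as np

def get_mixed_batch(anchors, positives, negatives, embed,
                    n_hard=16, n_random=16, margin=0.2, rng=None):
    """Pick the n_hard highest-loss triplets plus n_random other ones from a pool."""
    rng = np.random.default_rng() if rng is None else rng
    a, p, n = embed(anchors), embed(positives), embed(negatives)
    d_ap = np.sum((a - p) ** 2, axis=1)
    d_an = np.sum((a - n) ** 2, axis=1)
    losses = np.maximum(d_ap - d_an + margin, 0.0)
    order = np.argsort(losses)[::-1]  # indices sorted from highest to lowest loss
    hard_idx = order[:n_hard]
    random_idx = rng.choice(order[n_hard:], size=n_random, replace=False)
    selected = np.concatenate([hard_idx, random_idx])
    return anchors[selected], positives[selected], negatives[selected]

# Demo: a pool of 200 random triplets of 4-number "images", identity embedding.
rng = np.random.default_rng(0)
pool = [rng.normal(size=(200, 4)) for _ in range(3)]
a, p, n = get_mixed_batch(*pool, embed=lambda x: x, rng=rng)
print(a.shape)  # (32, 4)
```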
Note that many strategies can be used here, each with an impact on training speed. With well-chosen triplets the network converges faster, but spending too much computation on choosing triplets can make training slower overall. To go further, see this excellent article that discusses triplet mining.
Building our Keras model
The Neural Network
So, which architecture should we use for our neural network? Since we are focusing on the loss function here, the architecture is not very relevant: we could take anything from a simple network to a powerful one such as ResNet. For the sake of simplicity, let’s take a simple one: three Convolution/Pooling stacks followed by one fully connected layer. We must make sure that the final layer has no activation function, so that the embedding values can span their full range.
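A sketch of such a network in `tensorflow.keras`. The filter counts and kernel sizes are illustrative assumptions, not the article’s exact values. The final `Dense` layer has no activation, and a `UnitNormalization` layer (available in recent Keras versions) then enforces the L2-normalization property listed earlier.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

def build_embedding_network(input_shape=(28, 28, 1), embedding_size=10):
    """Three Conv/Pool stacks + one fully connected layer (illustrative sizes)."""
    return keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Conv2D(32, 3, padding="same", activation="relu"),
            layers.MaxPooling2D(),                 # 28x28 -> 14x14
            layers.Conv2D(64, 3, padding="same", activation="relu"),
            layers.MaxPooling2D(),                 # 14x14 -> 7x7
            layers.Conv2D(128, 3, padding="same", activation="relu"),
            layers.MaxPooling2D(),                 # 7x7 -> 3x3
            layers.Flatten(),
            layers.Dense(embedding_size),          # no activation: full range of values
            layers.UnitNormalization(),            # project onto the unit hypersphere
        ],
        name="embedding_network",
    )

embedding_network = build_embedding_network()
sample = np.random.rand(2, 28, 28, 1).astype("float32")
embeddings = embedding_network.predict(sample, verbose=0)  # shape (2, 10)
```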
The triplet loss function, implemented as a custom Keras layer
Now our full Keras model, with the anchor, the positive and the negative pictures as inputs
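A minimal sketch of both pieces, assuming `tensorflow.keras`: a custom layer that computes the triplet loss (squared Euclidean distance, illustrative margin of 0.2) and registers it with `add_loss`, and a model that runs the same embedding network on the anchor, positive and negative inputs. The names `TripletLossLayer` and `build_triplet_model` are hypothetical, and the small dense `demo_embedding` network only stands in for the convolutional one.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class TripletLossLayer(layers.Layer):
    """Computes max(d(A,P) - d(A,N) + margin, 0) and registers its mean via add_loss."""

    def __init__(self, margin=0.2, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def call(self, inputs):
        anchor, positive, negative = inputs
        d_ap = tf.reduce_sum(tf.square(anchor - positive), axis=-1)  # squared distances
        d_an = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
        per_triplet = tf.maximum(d_ap - d_an + self.margin, 0.0)
        self.add_loss(tf.reduce_mean(per_triplet))
        return per_triplet

def build_triplet_model(embedding_network, input_shape=(28, 28, 1), margin=0.2):
    anchor_in = keras.Input(shape=input_shape, name="anchor")
    positive_in = keras.Input(shape=input_shape, name="positive")
    negative_in = keras.Input(shape=input_shape, name="negative")
    # The same embedding network (shared weights) encodes all three pictures.
    embeddings = [embedding_network(x) for x in (anchor_in, positive_in, negative_in)]
    per_triplet_loss = TripletLossLayer(margin=margin, name="triplet_loss")(embeddings)
    model = keras.Model([anchor_in, positive_in, negative_in], per_triplet_loss)
    model.compile(optimizer="adam")  # no `loss=`: the add_loss term is the training loss
    return model

# Stand-in embedding network for the demo; swap in the convolutional one.
demo_embedding = keras.Sequential(
    [keras.Input(shape=(28, 28, 1)), layers.Flatten(), layers.Dense(10), layers.UnitNormalization()]
)
triplet_model = build_triplet_model(demo_embedding)
```

Because the loss is registered through `add_loss`, the model is compiled without a `loss` argument and can be trained on inputs only, for example with `triplet_model.train_on_batch([anchors, positives, negatives])`.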
Evaluation and metrics
In a traditional classifier, our performance would be based on the best prediction score of all the classes, leading to a confusion matrix with Precision, recall or F1 metrics.
In our situation, our model produces embeddings that we can use to compute distances, so we cannot apply the same system to evaluate its performance. If the two pictures are from the same class, the distance should be “low”; if they are from different classes, it should be “high”. So we need a threshold: if the distance is below the threshold, the decision is “same”; if it is above, the decision is “different”. We have to choose this threshold carefully: too low and we will have high precision but too many false negatives; too high and we will have too many false positives. This is a ROC curve problem, so one metric for evaluation could be the Area Under the Curve (AUC). In this article, we choose the threshold so that the False Positive Rate is under 10^-3 and then evaluate sensitivity (recall).
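A sketch of this procedure with scikit-learn’s `roc_curve` and `roc_auc_score`, assuming you already have an array of pair distances and a 0/1 array saying whether each pair comes from the same class (the helper name is hypothetical). Since a small distance means “same”, the negated distance is used as the ROC score. The function returns the AUC, the distance threshold that keeps the false positive rate at or below `max_fpr` with the highest recall, and that recall.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

def evaluate_at_fpr(distances, same, max_fpr=1e-3):
    """AUC, plus the distance threshold and recall at a false positive rate <= max_fpr.

    distances: distance of each pair; same: 1 if the pair is from the same class, else 0.
    A pair is declared "same" when its distance is <= the returned threshold.
    """
    scores = -np.asarray(distances)  # small distance = "same", so negate it for the ROC
    auc = roc_auc_score(same, scores)
    fpr, tpr, thresholds = roc_curve(same, scores, drop_intermediate=False)
    best = np.flatnonzero(fpr <= max_fpr)[-1]  # fpr never decreases: last index = highest recall
    return auc, -thresholds[best], tpr[best]

distances = np.array([0.1, 0.2, 0.3, 1.0, 1.2, 1.5])
same = np.array([1, 1, 1, 0, 0, 0])
auc, threshold, recall = evaluate_at_fpr(distances, same)
print(auc, threshold, recall)  # 1.0 0.3 1.0
```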
Our evaluation process takes the test dataset, computes the distance between every pair of pictures, and then computes the AUC using the sklearn function.
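One way to build those pairs from the test embeddings, sketched in NumPy with a hypothetical helper name: compute the full distance matrix using the identity ||a - b||² = ||a||² + ||b||² - 2 a·b, then keep each unordered pair once. The memory needed grows with the square of the number of images, so evaluating on a random subset of the test set is a cheap alternative.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def all_pairs(embeddings, labels):
    """Distance and same-class flag (1/0) for every unordered pair i < j."""
    sq_norms = np.sum(embeddings ** 2, axis=1)
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * embeddings @ embeddings.T
    dist = np.sqrt(np.maximum(sq_dist, 0.0))  # clip tiny negative rounding errors
    same = labels[:, None] == labels[None, :]
    i, j = np.triu_indices(len(embeddings), k=1)
    return dist[i, j], same[i, j].astype(int)

# Demo: two tight, well-separated clusters give a perfect AUC.
emb = np.array([[0.0, 0.0], [0.0, 0.1], [5.0, 5.0], [5.0, 5.1]])
lab = np.array([0, 0, 1, 1])
distances, same = all_pairs(emb, lab)
print(roc_auc_score(same, -distances))  # 1.0
```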
Another interesting metric to look at during training is how “far” the embeddings from each class are from each other. To be thorough, we should evaluate the whole dataset, but here this is just to check that the network is converging smoothly for all the classes.
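A small NumPy sketch of such a check, using hypothetical helpers run on a sample of embeddings that contains every class: entry (i, j) of the matrix is the mean Euclidean distance between embeddings of class i and class j, and the mean of the off-diagonal entries gives a single “inter-class distance” to follow during training.

```python
import numpy as np

def class_distance_matrix(embeddings, labels, num_classes=10):
    """Mean distance between embeddings of class i and class j (every class must be present)."""
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    matrix = np.zeros((num_classes, num_classes))
    for i in range(num_classes):
        for j in range(num_classes):
            matrix[i, j] = dist[np.ix_(labels == i, labels == j)].mean()
    return matrix

def mean_interclass_distance(matrix):
    """Average of the off-diagonal entries, i.e. over pairs of different classes."""
    off_diagonal = ~np.eye(len(matrix), dtype=bool)
    return matrix[off_diagonal].mean()

emb = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
lab = np.array([0, 0, 1, 1])
m = class_distance_matrix(emb, lab, num_classes=2)
print(mean_interclass_distance(m))  # 1.0
```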
Now let’s train our model. Each loop iteration builds a new batch and processes it.
Here’s our ROC curve after training:
Now let’s look at the distances between classes. On average, they moved from 0.5 to 1.4, which indicates that our network now produces much better separated embeddings for each class.
Let’s play with our network by comparing a few test images against a reference picture from each class. We can clearly see that our network is now a solid similarity comparer:
- the distance between the test image and another image of the same class is low (on the order of 10^-2)
- the distance between the test image and an image of another class is high (on the order of 10^0)
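Comparing against one reference picture per class is also how a one shot decision can be made in practice. This NumPy sketch (hypothetical names, embeddings assumed to be already computed) returns the distance to each class reference and the closest class. The optional `threshold` is an assumption added for illustration: it answers “unknown” (-1) when even the closest reference is too far, which is what a security gate would need for a non-employee.

```python
import numpy as np

def compare_to_references(test_embedding, reference_embeddings, threshold=None):
    """Distances to one reference embedding per class, plus the predicted class.

    The predicted class is -1 when a threshold is given and no reference is close enough.
    """
    distances = np.linalg.norm(reference_embeddings - test_embedding, axis=1)
    best = int(np.argmin(distances))
    if threshold is not None and distances[best] > threshold:
        best = -1
    return distances, best

references = np.eye(3)  # hypothetical embeddings of 3 reference pictures
distances, predicted = compare_to_references(np.array([0.9, 0.1, 0.0]), references)
print(predicted)  # 0
_, unknown = compare_to_references(np.array([-1.0, 0.0, 0.0]), references, threshold=0.5)
print(unknown)  # -1
```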
Conclusion
- One shot learning is another approach to classification. It can be used when the number of “classes” changes too often and/or there is not enough data per class.
- It can be used with many different neural network architectures.
- The Triplet Loss technique is one way of training the network. It requires a strategy for choosing good triplets to feed the network during training.
I hope this helped you understand the “one shot learning” methodology using Deep Learning; I found it very interesting.
The full code is available here on my GitHub.
This code was built on Azure Notebooks, running on an Azure Data Science Virtual Machine (with an Nvidia K80 GPU).
References and thanks
Andrew Ng’s Convolutional Neural Networks course (Deep Learning Specialization) on Coursera
2015 FaceNet: A Unified Embedding for Face Recognition and Clustering by Florian Schroff, Dmitry Kalenichenko and James Philbin
Another detailed explanation of triplet loss and triplet mining, with TensorFlow, by Olivier Moindrot
Tess Ferrandez, for her Loss implementation, which is better than mine