How to train Siamese Network on COVID-19 X-ray images

Yiwen Lai
Published in Analytics Vidhya
8 min read · Apr 28, 2020


Photo by MJ Tangonan on Unsplash

In this post, I am not solving any problem related to COVID-19, but exploring the concepts and implementation of a Siamese network (SiameseNet) on X-ray images. If you wish to look at the code, you can click on the link below.

Meta-Learning

Meta-learning is one of the most promising fields in artificial intelligence. Some schools of thought in the AI community believe that meta-learning is a stepping stone towards unlocking artificial general intelligence (AGI). The idea behind this technique is to create a process built on the concept of learning to learn.

To appreciate why meta-learning is an important milestone, we can look at how deep learning classification works. Imagine we need to build a dog-and-cat classifier: we provide hundreds of images of dogs and cats for training and obtain a trained model. When a hamster image is given to this classifier, the model will fail miserably. It will predict the hamster as either a dog or a cat, whereas even a 5-year-old kid would recognise it as a new class of pet.

So how do we humans learn? We are able to learn by looking at just a few images and can easily identify a hamster as a species outside of cats and dogs. But for our model to predict correctly, we would need to provide hundreds of images of hamsters and retrain the model.

When Corona doesn't care that you are just a hamster.

What if these images are hard to come by? For example, in medical imaging, positive cases of a certain illness are usually far rarer than negative cases (healthy patients). Meta-learning offers a solution to these problems: a more general model that can detect a new class from only a few images, without re-training. This is much closer to how humans learn compared to standard image classification.

Different types of meta-learning

Meta-learning can be broadly characterised into 3 categories: learning the metric space, learning the initialization, and learning the optimizer.

Learning the metric space

Learning the metric space means having the neural network learn to extract features from the inputs and place them in a high-dimensional vector space. Let's say we want to learn to identify images of 2 different classes. We use a neural network to extract features from these images and compute a similarity distance between them. At the end of training, we want samples of the same class to be close together and samples of different classes to be far apart. There are many metric-based learning algorithms; one such algorithm is the Siamese Network, which will be explained in more detail later. Others are Prototypical Networks and Matching Networks; they will not be covered in this post, but I will provide some references if you wish to explore further.
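As a toy illustration of a learned metric space (the 4-d embeddings below are made up, standing in for the features a trained network would produce):

```python
import numpy as np

# Hypothetical 4-d embeddings produced by some trained feature extractor.
cat_a = np.array([0.9, 0.1, 0.2, 0.0])
cat_b = np.array([0.8, 0.2, 0.1, 0.1])
dog_a = np.array([0.1, 0.9, 0.7, 0.5])

def dist(u, v):
    # Euclidean distance in the learned metric space.
    return np.linalg.norm(u - v)

# After training, same-class distances should be smaller
# than cross-class distances.
print(dist(cat_a, cat_b) < dist(cat_a, dog_a))  # True
```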

Learning the initialization

For this method, the approach is to learn the optimal initial parameters or weights for the model. Instead of initializing with random weights, we start training from these optimal parameters. With this, the model converges faster and requires less data during training. This uses a similar concept to transfer learning, where the objective is to use previously acquired knowledge to aid us on a new task. There are many initialization algorithms, such as MAML and Reptile, as well as the currently fast-growing field of self-supervised learning.

Learning the optimizer

For this method, the algorithm tries to learn the optimizer function itself. As an example, imagine we are training a neural network by computing a loss and minimising it through gradient descent. Usually, we would use SGD or Adam as the optimizer. Instead of using these hand-crafted optimizers, what if we could learn the optimization process itself? One approach is to replace the traditional optimizer with a Recurrent Neural Network. If you would like to learn more, please refer to the link provided below.

Sorry for the lengthy introduction, but I find it important to share, because one of the next big turning points in AI is the ability to generalise faster from fewer examples. So let's jump into the implementation of SiameseNet.

If you are interested in future development on AI, you can read about it on the link below.

Summary

  1. Train / Test split
  2. Create a custom Generator
  3. Visualising images from Generator
  4. SiameseNet architecture
  5. N-way Evaluation
  6. How do we use SiameseNet

1. Train / Test split

Train set: 10 images per class; the rest go to the Test set

For this project, there are a total of 3 classes, and we will only be using 10 samples from each class for training. The 3 classes consist of X-ray images from COVID-19, Bacterial and Normal patients. If you read other posts about one-shot or multi-shot learning, they all share the same concept: using a small fraction of the data to train the model. In our case, this is 3-class, 10-shot learning.
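A sketch of the split (file names, class sizes and the seed are invented placeholders; the repo's actual folder layout will differ):

```python
import random

# Hypothetical file lists per class; in practice these would come from
# the COVID-19, Bacterial and Normal image folders.
classes = {
    "covid": [f"covid_{i}.png" for i in range(50)],
    "bacterial": [f"bact_{i}.png" for i in range(60)],
    "normal": [f"norm_{i}.png" for i in range(70)],
}

SHOTS = 10  # images per class used for training (10-shot)
rng = random.Random(42)

train, test = {}, {}
for name, files in classes.items():
    files = files[:]          # copy so we don't mutate the source list
    rng.shuffle(files)
    train[name] = files[:SHOTS]  # 10 images per class for training
    test[name] = files[SHOTS:]   # everything else held out for evaluation
```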

2. Create a custom Generator

Example of how the image pairs will look

To train our network we need a custom generator that produces image pairs for training and validation. The generator provides sets of correct pairs and incorrect pairs along with their respective labels. The figure above gives an example of how the pairs look: correct pairs are given label 1 and incorrect pairs label 0. The generator also ensures that the sets of correct and incorrect pairs are balanced. When we set the batch size to 30, it produces shapes of (2, 30, 100, 100, 3) and (30,), which can be read as follows:

(2 pair branches, 30 images, 100 width, 100 height, 3 channels), (30 labels)
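The pairing logic can be sketched with NumPy (a sketch, not the repo's actual generator; the function name and the strict alternation between correct and incorrect pairs are my own assumptions):

```python
import numpy as np

def pair_generator(images_by_class, batch_size=30, rng=None):
    """Yield balanced batches of image pairs and labels.

    images_by_class: dict mapping class name -> array of images
                     with shape (n, 100, 100, 3).
    Output shapes: (2, batch_size, 100, 100, 3) and (batch_size,).
    """
    rng = rng or np.random.default_rng(0)
    names = list(images_by_class)
    while True:
        left, right, labels = [], [], []
        for i in range(batch_size):
            if i % 2 == 0:
                # Correct pair: both images drawn from the same class.
                c = rng.choice(names)
                a, b = rng.choice(images_by_class[c], 2)
                labels.append(1)
            else:
                # Incorrect pair: images drawn from two different classes.
                c1, c2 = rng.choice(names, 2, replace=False)
                a = rng.choice(images_by_class[c1])
                b = rng.choice(images_by_class[c2])
                labels.append(0)
            left.append(a)
            right.append(b)
        yield np.stack([left, right]), np.array(labels)
```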

3. Visualising images from Generator

The following are examples of what the model will see when training

In this project, I did not use image augmentation. I wanted to see whether, with only 10 images per class, the model could produce a good result. (Spoiler: yes, the result is pretty good.)

4. Siamese Network

The intuition of the Siamese network is to create twin models that extract features and compute the difference between the 2 images fed in. We want the model to learn this difference and create embeddings that separate into distinct clusters.

SiameseNet architecture used in this post

The 2 base models highlighted in blue are not different networks but exact copies of each other: they share the same weights. We will be reusing the trained model from the previous post, as this greatly improves the convergence rate compared to training from ImageNet weights. In my tests, training from ImageNet weights took about ~1600 epochs, while with our pre-trained weights the model converged in ~65 epochs. That is roughly 24x faster.

When 2 images are passed into our model as input, the model generates 2 feature vectors (embeddings). We then compute the difference between the features and apply a sigmoid to output a similarity score. During training, errors are backpropagated to correct the mistakes the model made when creating the feature vectors. This is how our SiameseNet learns from pairs of images.
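A minimal NumPy forward pass can illustrate the weight sharing and the |difference| → sigmoid head. This is a toy single-layer encoder standing in for the real base model; every weight here is a random placeholder, not a trained value:

```python
import numpy as np

rng = np.random.default_rng(0)

# One shared projection stands in for the twin base model: both
# branches use the SAME W, which is what "sharing weights" means.
W = rng.standard_normal((100 * 100 * 3, 16)) * 0.01
w_out = rng.standard_normal(16) * 0.1  # weights of the sigmoid head
b_out = 0.0

def embed(img):
    # Flatten the image and project it to a 16-d feature vector (ReLU).
    return np.maximum(img.reshape(-1) @ W, 0)

def similarity(img_a, img_b):
    # |f(a) - f(b)| fed into a sigmoid gives a score in (0, 1).
    diff = np.abs(embed(img_a) - embed(img_b))
    return 1 / (1 + np.exp(-(diff @ w_out + b_out)))
```

With identical inputs the absolute difference is all zeros, so this untrained head outputs exactly sigmoid(0) = 0.5; training would push scores for correct pairs towards 1 and incorrect pairs towards 0.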

If you wish to see how our base model is created, please refer to my previous post.

5. N-way Evaluation

To evaluate our model, we have to create a custom validation process. The test has to determine whether the model can recognize a similar image among different ones. One way to do this is N-way one-shot learning; it may sound difficult but is actually very simple.

The example shows 5-way one-shot learning. Note that the left-hand side images are all the same.

For this project, we test the model with N=20 (20-way one-shot learning). Imagine we are creating an exam paper for the model, where each question is multiple choice with 20 options and only one answer is correct. We give the model 100 questions and see how well it performs. The model gives each pair a score, and the pair with the highest score is the one it chooses as correct. You can imagine how tough this paper would be if given to a person with 20 choices per question.
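The marking scheme can be sketched as follows (a hypothetical helper, not taken from the repo): score every (query, candidate) pair, pick the argmax, and count how many questions the model answered correctly:

```python
def n_way_accuracy(score, trials):
    """Estimate N-way one-shot accuracy.

    score:  function (query, candidate) -> similarity score.
    trials: list of (query, candidates, index_of_correct_candidate),
            where exactly one candidate matches the query's class.
    Returns the fraction of trials where the highest-scoring
    candidate was the correct one.
    """
    correct = sum(
        1 for query, candidates, true_idx in trials
        if max(range(len(candidates)),
               key=lambda i: score(query, candidates[i])) == true_idx
    )
    return correct / len(trials)
```

As a toy check, a perfect scorer on 1-d "images" (plain numbers, using negative distance as similarity) answers every question correctly:

```python
score = lambda a, b: -abs(a - b)
trials = [(5, [100, 5, 42], 1), (7, [7, 99, 0], 0)]
print(n_way_accuracy(score, trials))  # 1.0
```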

In the end, the model did quite well: it got 96.0% correct when N=20 and 88% when N=30. Which is awesome! Remember, we only provided the model with 10 images per class.

When you realise you have 20 selections for MCQ.

6. How do we use SiameseNet

So what can we do with a SiameseNet after training? When a new, unseen COVID-19 X-ray image is given, we can use the model to create an image embedding. We can then measure this embedding's similarity against our 3 class clusters. The new unseen COVID-19 image should be close to our COVID-19 cluster.

So what about the unknown class we talked about earlier? Let's say another X-ray image comes in with a new, unknown infection. The embedding generated will be an outlier, falling outside all 3 of our clusters. This flags the X-ray image as an unknown class, and doctors can be brought in for further verification.
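A sketch of this nearest-cluster-with-threshold idea (the centroids and the distance threshold below are made-up illustrative values, not numbers from the trained model; in practice each centroid would be the mean embedding of that class's 10 training images):

```python
import numpy as np

# Hypothetical cluster centroids in a 4-d embedding space.
centroids = {
    "covid":     np.array([1.0, 0.0, 0.0, 0.0]),
    "bacterial": np.array([0.0, 1.0, 0.0, 0.0]),
    "normal":    np.array([0.0, 0.0, 1.0, 0.0]),
}
THRESHOLD = 1.2  # assumed max distance before flagging an outlier

def classify(embedding):
    # Pick the nearest class cluster; if the embedding is far from
    # every cluster, flag it as an unknown class for human review.
    name, d = min(
        ((c, np.linalg.norm(embedding - mu)) for c, mu in centroids.items()),
        key=lambda t: t[1],
    )
    return name if d <= THRESHOLD else "unknown"
```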


🤖 AI² | NTU Computer Science Graduate | NUS M.Tech Knowledge Engineering | https://twitter.com/Niel_Lai