Understanding SimCLR — A Simple Framework for Contrastive Learning of Visual Representations with Code

Aditya Rastogi · Published in Analytics Vidhya · 8 min read · Apr 4, 2020

This paper [1] presents a simple framework (which the authors call SimCLR) for contrastive learning of visual representations. These visual representations are vectors on which linear classifiers can be trained to solve problems like image classification. We can learn such representations by training deep learning models like ResNet on labeled datasets like ImageNet, but labeling and annotating data is time-consuming and labor-intensive, so we would like to avoid it as much as possible. How, then, can we learn these visual representations without human supervision? Self-supervised learning is a technique in which the training data is automatically labeled by finding and exploiting correlations between different input features, and contrastive learning, the approach this paper takes, is one such technique. Let's look at what it is with the help of an example.

Contrastive Learning

Contrastive learning approaches learn representations by contrasting positive pairs against negative pairs. Let's understand what these positive and negative pairs are through an example.

Suppose we have the following batch of 25 images. They come from 5 categories, namely airplane, car, dog, elephant, and cat, but note that these labels will not be used to train the model. We will, however, use them for visualization purposes while learning about the algorithm.

A batch of 25 example images from this dataset

Each of these 25 images is of size 224 x 224. To each image we apply, twice, a composition of two data augmentation operations: a random crop followed by a resize back to 224 x 224, and color distortion. This gives us two new images per original image.

From left to right: original image, two images obtained by applying a composition of random crop & resize and color distortion, two times on the original image.

The code to apply these data augmentations in PyTorch is as follows.
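A minimal sketch using torchvision transforms (the crop scale, jitter strengths, and grayscale probability are assumptions following common SimCLR settings; the original gist may differ):

```python
import torchvision.transforms as transforms

# Color distortion = random color jitter + random grayscale (strengths assumed).
color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)

data_augment = transforms.Compose([
    transforms.RandomResizedCrop(224),              # random crop, then resize to 224 x 224
    transforms.RandomApply([color_jitter], p=0.8),  # color distortion
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])

# Applying the composition twice to the same image gives the two augmented views.
# x_i, x_j = data_augment(img), data_augment(img)
```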

We apply this composition of data augmentations to each of our 25 images to get 50 augmented images, and we work with these 50 augmented images from here on. Among these images, we define positive pairs as those pairs of images that were obtained from the same original image, so there are 25 positive pairs. In general, for a batch of N images we get 2N augmented images, and given a particular positive pair (i, j) from these 2N images, we treat the other 2(N-1) images as negative examples for i and j.

So, now that we know what positive and negative pairs are, let’s have a look at how the authors use these to learn visual representations.

Source : [1]

So far we have obtained these augmented images. We feed each positive pair through a neural network (the composition of f and g, as shown in the image above) and obtain vectors z_i and z_j. We then maximize the agreement between these vectors: we want different views of the same image to have similar representations.

The agreement between these vectors is maximized by minimizing the contrastive loss (the normalized temperature-scaled cross-entropy loss, or NT-Xent for short) between them. Let's understand this loss in detail.

Normalized temperature-scaled cross-entropy loss


Let sim represent the cosine similarity function as shown below.

cosine similarity
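For two vectors u and v, as defined in [1]:

\mathrm{sim}(u, v) = \frac{u^{\top} v}{\lVert u \rVert \, \lVert v \rVert}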

Then, NT-Xent loss for a positive pair of examples ( i, j ) is defined as

NT-Xent Loss function for a positive pair of examples ( i, j )
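With τ as the temperature parameter and 1_{[k ≠ i]} an indicator that equals 1 when k ≠ i, the loss from [1] reads:

\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}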

By simplifying the above loss function using the properties of the logarithm and some algebraic manipulation, we get a slightly more interpretable form.
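Using log(a/b) = log a - log b:

\ell_{i,j} = -\frac{\mathrm{sim}(z_i, z_j)}{\tau} + \log \sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\frac{\mathrm{sim}(z_i, z_k)}{\tau}\right)

The first term rewards similarity with the positive example z_j, while the second term penalizes similarity with every other representation in the batch.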

We see that by minimizing the NT-Xent loss for a positive pair of examples ( i, j ), we not only make the vector z_i more similar to z_j but also make it dissimilar to all the other vectors.

Note that the loss is asymmetric, i.e. l_{i, j} != l_{j, i}, because the sum in the denominator runs over different terms. So, in order to compute the final loss, we compute it across all positive pairs (with (i, j) and (j, i) counted separately) and then take the average.
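Writing the two augmentations of the k-th original image as indices 2k-1 and 2k, the final loss from [1] is:

L = \frac{1}{2N} \sum_{k=1}^{N} \left[ \ell_{2k-1,\,2k} + \ell_{2k,\,2k-1} \right]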

Here τ is a temperature hyperparameter that scales the similarities before the softmax-like normalization, controlling how sharply the loss concentrates on the hardest negatives. Various values of it can be tried, and the authors do so in the paper.

Below is the code for this loss function in PyTorch.
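A sketch of the loss in PyTorch (variable names follow the explanation below; the temperature default and the exact implementation details are assumptions, and the original gist may differ):

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(a, b, tau=0.1):
    # a, b: (N, d) tensors of representations; a[i] and b[i] form a positive pair.
    N = a.shape[0]
    # Unit vectors, so that the dot products below are cosine similarities.
    a_cap = F.normalize(a, dim=1)
    b_cap = F.normalize(b, dim=1)
    a_cap_b_cap = torch.cat([a_cap, b_cap], dim=0)       # (2N, d)
    # Pairwise cosine similarities between all 2N representations, scaled by tau.
    sim = a_cap_b_cap @ a_cap_b_cap.t() / tau            # (2N, 2N)
    # Exclude self-similarities from the denominator.
    mask = torch.eye(2 * N, dtype=torch.bool, device=sim.device)
    sim = sim.masked_fill(mask, float('-inf'))
    # For row i in [0, N), the positive sits at column N + i, and vice versa.
    targets = torch.cat([torch.arange(N, 2 * N), torch.arange(0, N)]).to(sim.device)
    # Cross-entropy over each row is -log(exp(positive) / sum over k != i of exp(sim)),
    # and averaging over all 2N rows averages over both (i, j) and (j, i).
    return F.cross_entropy(sim, targets)
```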

Understanding the above code (feel free to skim the explanation if the code is already clear to you):

In the above code, a and b are tensors containing N representations each of the augmented images, and the order is maintained in the sense that for i from 1 to N, a[i] and b[i] together form a positive pair. The loss function needs similarities between pairs of vectors, and the cosine similarity between two vectors is the dot product of their respective unit vectors. So we find the unit vector for each row of a and b and store the results in a_cap and b_cap. We need similarities between all pairs chosen from the 2N examples, so we first concatenate a_cap and b_cap to get a single tensor a_cap_b_cap of 2N representations. The crucial step is that the matrix product of a_cap_b_cap with its transpose gives the similarity matrix whose (i, j)-th entry is the similarity between representation i and representation j, where both i and j vary from 1 to 2N. The rest of the code should be clear from the names of the intermediate variables.

Consider the following four positive pairs which are obtained from two different car images.

By maximizing agreement between the representations of positive pairs, we are essentially making the representations of cars and of car parts (because of random cropping) close to each other, with less emphasis on color (because of color jittering). This increases the similarity between the representations of cars as a whole.

Results

We train a ResNet using the NT-Xent loss function on our dataset, which contains 1250 training images (250 per category) and 250 test images (50 per category).

The code for the ResNet is as follows. We use a resnet18 model and replace its top layer with a few fully connected layers.

Resnet 18 with the last layer replaced with a non-linear classifier
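A minimal sketch of this encoder (the 50- and 25-dimensional layer sizes match the article; the choice of activation and the absence of pretrained weights are assumptions):

```python
import torch.nn as nn
import torchvision.models as models

class SimCLRResNet(nn.Module):
    def __init__(self):
        super().__init__()
        # resnet18 backbone trained from scratch; its 512-dim final layer is
        # replaced with a small non-linear head ending in a 25-dim output.
        self.backbone = models.resnet18(pretrained=False)
        self.backbone.fc = nn.Sequential(
            nn.Linear(512, 50),   # second-last layer: 50-dimensional
            nn.ReLU(),
            nn.Linear(50, 25),    # last layer: 25-dimensional
        )

    def forward(self, x):
        return self.backbone(x)
```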

We got the following graph of NT-Xent loss vs. the number of epochs while training the above resnet.

Plot of training loss vs. the number of epochs

We randomly select 10% of the examples from our training data and reveal their labels in order to train a linear classifier on top of the learned representations. This can be thought of as follows: suppose you are seeing Devanagari numerals for the first time and you don't know which numeral represents which digit. You can still identify ten groups in the set of numerals. So, in order to identify Devanagari numerals, all you need to learn is which cluster corresponds to which digit. Hence, far less labeling is required, because some sort of clustered representation of the numerals is already in your mind.

We visualize the last layer vectors (25-dimensional in our case) for our test set (250 images) and training set (only the 10% subset, 125 images) using t-SNE. In the following scatter plot, we use the labels to give the same color to images that belong to the same category.

t-SNE visualization of the last layer vectors

The following is the t-SNE visualization of the second-last layer vectors (50-dimensional in our case).

t-SNE visualization of the second last layer vectors

We see some clusters in the above images, which is a good indication.

The code to obtain the above t-SNE visualizations is as follows.
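A sketch of how such a plot can be produced with scikit-learn and matplotlib (the function and variable names here are illustrative, not the original code):

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_tsne(features, labels):
    # features: (n_samples, d) array of layer activations,
    # labels: (n_samples,) integer category labels, used only for coloring.
    embedded = TSNE(n_components=2).fit_transform(features)
    plt.figure(figsize=(6, 6))
    plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap='tab10', s=10)
    plt.show()
```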

We now train a linear classifier on the representations obtained from the second last layer (50 dimensional). The linear classifier’s code is as follows.
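A sketch of the linear evaluation setup (the output size of 5 matches the number of categories; the optimizer and learning rate are assumptions):

```python
import torch
import torch.nn as nn

# A single fully connected layer mapping the frozen 50-dimensional
# representations to the 5 categories, trained with cross-entropy.
linear_classifier = nn.Linear(50, 5)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(linear_classifier.parameters(), lr=1e-3)
```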

We obtain the following plots of training and testing accuracy & cross-entropy losses vs. the number of epochs.

Plots of accuracy and loss vs. the number of epochs obtained while training a linear classifier on 10% labeled training data

Comparison with a Supervised classifier

In order to understand how good the obtained results are, we compare them with the results of a supervised classifier (a ResNet-18 model trained directly on the labels). Table 1 shows the comparison. The accuracy and loss plots for the supervised classifier are as follows.

Plots of accuracy and loss vs. the number of epochs

The sudden increase in accuracy around 180 epochs is when we turned off data augmentation, after which the model started overfitting. We stored the model with the best test accuracy.

We see a 14% accuracy gap between the supervised and unsupervised classifiers. We used a batch size of 256 and the ResNet-18 model. This 14% gap can be reduced further by using larger batch sizes and bigger, wider models, because SimCLR benefits more from these than supervised learning does. But that comes with a requirement of greater computational power (momentum contrast with ideas from SimCLR, i.e. MoCo-v2, can help with this; I wrote a blog post about it here).

The code for this article is in this GitHub repo.

References

[1] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey E. Hinton. A Simple Framework for Contrastive Learning of Visual Representations. arXiv preprint arXiv:2002.05709, 2020.


Thank you for taking the time to read this article.
