Self-Supervised Learning: Blending Essence of Biological Intelligence into Machines

Rabia Eda Yılmaz
Published in albert-health
Aug 29, 2023
Photo by Pierre Bamin on Unsplash

In this post, we will go through the Self-Supervised Learning (SSL) paradigm, which Facebook has called the dark matter of intelligence. First, a quick background is provided. The next section takes a deep dive into SSL, and the last section describes commonly used SSL models.

1. Background 💭

Energy-Based Model (EBM)

An energy-based model (EBM) is a trainable system that, given two inputs, x and y, tells us how incompatible they are with each other [1]. For example, x could be a short video clip and y another proposed video clip. The machine produces a single number, called an energy, that indicates the incompatibility between x and y. If the energy is low, x and y are deemed compatible; if it is high, they are deemed incompatible.

A related idea is the Siamese network, or joint embedding architecture. It is composed of two identical or almost identical copies of the same network. One network is fed with x and the other with y. The networks produce output vectors, called embeddings, that represent x and y. A third module joins the networks at the head and computes the energy as the distance between the two embedding vectors. So, when the model is shown distorted versions of the same image, the parameters of the networks can easily be adjusted so that their outputs move closer together. This ensures that the network produces nearly identical representations (embeddings) of an object.

Joint Embedding Architecture. C produces a scalar energy that measures the distance between the representation vectors (embeddings) produced by two identical twin networks sharing the same parameters w [1]
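As a rough sketch of the joint embedding idea (our own toy illustration, not the architecture from [1]), the snippet below uses one random linear map as a hypothetical stand-in for the twin networks. Because both branches call the same `encode` function, they share the parameters w, and the module C is assumed to be the squared Euclidean distance between embeddings. Even this untrained encoder gives a slightly distorted view a lower energy than an unrelated input; training would adjust w so that this holds for realistic distortions.

```python
import random

def make_encoder(in_dim, out_dim, seed=0):
    """Hypothetical twin network: a random linear map; both branches reuse the same weights w."""
    rng = random.Random(seed)
    w = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]

    def encode(x):
        return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]

    return encode

def energy(encode, x, y):
    """Module C: squared Euclidean distance between the embeddings of x and y."""
    ex, ey = encode(x), encode(y)  # same function for both inputs = shared parameters
    return sum((a - b) ** 2 for a, b in zip(ex, ey))

rng = random.Random(1)
image = [rng.random() for _ in range(16)]              # toy flattened "image"
distorted = [v + rng.gauss(0.0, 0.01) for v in image]  # slightly distorted view of it
other = [rng.random() for _ in range(16)]              # an unrelated image

encode = make_encoder(16, 4)
print(energy(encode, image, distorted))  # low energy: compatible pair
print(energy(encode, image, other))      # higher energy: incompatible pair
```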

However, there is a phenomenon to be aware of: collapse. It is difficult to make sure that the networks produce high energy, i.e., different embedding vectors, when x and y are different images. The two networks can ignore their inputs and always produce identical output embeddings; this is called collapse. When collapse occurs, the energy is no higher for non-matching x and y than it is for matching x and y.
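To make collapse concrete, here is a hypothetical `collapsed_encoder` that ignores its input. Every pair of inputs then gets zero energy, so matching and non-matching pairs become indistinguishable. Checking the per-dimension standard deviation of embeddings across a batch is one simple way to detect this (the variance term in VICReg builds on the same idea); the toy batch is our own.

```python
import statistics

def collapsed_encoder(x):
    """A collapsed network: ignores its input and always returns the same embedding."""
    return [0.5, -0.2, 1.0]

def energy(encode, x, y):
    """Squared Euclidean distance between the two embeddings."""
    ex, ey = encode(x), encode(y)
    return sum((a - b) ** 2 for a, b in zip(ex, ey))

def per_dim_std(embeddings):
    """Standard deviation of each embedding dimension across a batch; all zeros signals collapse."""
    return [statistics.pstdev(dim) for dim in zip(*embeddings)]

batch = [[1.0, 2.0], [3.0, -1.0], [0.0, 0.0]]  # three different toy inputs
print([energy(collapsed_encoder, batch[0], y) for y in batch])  # [0.0, 0.0, 0.0]
print(per_dim_std([collapsed_encoder(x) for x in batch]))       # [0.0, 0.0, 0.0]
```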

To avoid collapse, there are two categories of techniques: contrastive methods and regularization methods.

Contrastive Energy-based SSL

Contrastive methods are based on constructing pairs of x and y that are not compatible and adjusting the parameters of the model so that the corresponding output energy is large. Training an EBM with a contrastive method consists of simultaneously pushing down the energy of compatible (x, y) pairs from the training set and pushing up the energy of incompatible (x, y) pairs [1].
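A minimal sketch of this push-down/push-up objective, assuming squared Euclidean distance as the energy and a hinge with a `margin`. This is one classic form of contrastive loss, chosen for illustration; it is not necessarily the loss used by any specific model in [1].

```python
def sq_dist(a, b):
    """Energy of a pair: squared Euclidean distance between two representations."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def contrastive_loss(anchor, positive, negative, margin=1.0):
    e_pos = sq_dist(anchor, positive)  # compatible pair: minimizing this pushes its energy down
    e_neg = sq_dist(anchor, negative)  # incompatible pair: the hinge pushes its energy up
    return e_pos + max(0.0, margin - e_neg)  # no push once the negative clears the margin

print(contrastive_loss([0.0, 0.0], [0.0, 0.5], [2.0, 0.0]))  # 0.25: negative already far enough
print(contrastive_loss([0.0, 0.0], [0.0, 0.0], [0.5, 0.0]))  # 0.75: negative is too close
```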

In NLP, training a model by masking or substituting some input words belongs to the category of contrastive methods. However, it does not use the joint embedding architecture. Instead, it uses a predictive architecture in which the model directly produces a prediction for y [1]. First, a complete segment of text y is corrupted by masking some words to produce the observation x. Then, the corrupted input is fed to a large neural network that is trained to reproduce the original text y. An uncorrupted text will be reconstructed as itself (low reconstruction error), while a corrupted text will be reconstructed as an uncorrupted version of itself (large reconstruction error). If the reconstruction error is interpreted as an energy, clean text gets low energy and corrupted text gets higher energy.
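The corruption step and the reconstruction-error-as-energy view can be sketched as follows. This is a toy illustration under stated assumptions: `corrupt` is a hypothetical helper, the energy simply counts tokens that differ from the clean text, and the reconstructor is assumed to be ideal (it always outputs y).

```python
import random

MASK = "[MASK]"

def corrupt(tokens, mask_prob=0.15, rng=None):
    """Build the observation x by masking a random subset of the clean text y."""
    rng = rng or random.Random(0)
    x, masked_positions = [], []
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            x.append(MASK)
            masked_positions.append(i)
        else:
            x.append(tok)
    return x, masked_positions

def reconstruction_energy(text, reconstruction):
    """Energy = number of tokens the reconstructor had to change (a toy reconstruction error)."""
    return sum(a != b for a, b in zip(text, reconstruction))

y = "the cat sat on the mat".split()
x, masked = corrupt(y, mask_prob=0.5, rng=random.Random(1))

# Assume an ideal reconstructor that always outputs the clean text y.
print(reconstruction_energy(y, y))                  # 0: clean text has low energy
print(reconstruction_energy(x, y) == len(masked))   # True: energy grows with the corruption
```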

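The general technique of training a model to restore a corrupted version of an input is called a denoising auto-encoder. But we cannot easily use this trick for images. For text, the prediction at each masked position is a probability distribution over a finite vocabulary, so the uncertainty in the prediction is easy to represent; for images, it is not feasible to enumerate all possible images in the same way. There is no satisfactory solution to this problem yet. There are interesting ideas in this direction, but they have not yet led to results that are as good as joint embedding architectures.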
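A back-of-the-envelope count shows why enumeration works for words but not for images. The numbers are our own illustration: we assume BERT-base (uncased)'s WordPiece vocabulary of 30,522 tokens and a tiny 32x32 RGB image with 8-bit channels.

```python
import math

vocab_size = 30522                    # size of BERT-base (uncased) WordPiece vocabulary
height, width, channels = 32, 32, 3   # a tiny RGB image with 8-bit values (256 levels)

# Predicting one masked word = one softmax over vocab_size candidates.
num_word_candidates = vocab_size

# Predicting a whole image would mean one distribution over 256 ** (32 * 32 * 3) candidates.
# Count its decimal digits with logarithms instead of building the huge integer.
num_values = height * width * channels
image_count_digits = math.floor(num_values * math.log10(256)) + 1

print(num_word_candidates)   # 30522
print(image_count_digits)    # 7399
```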

One interesting avenue is latent-variable predictive architectures.

A latent-variable predictive architecture. Given an observation x, the model must be able to produce a set of multiple compatible predictions, shown as an S-shaped ribbon. As the latent variable z varies within a set, the output varies over the set of plausible predictions [1].

This type of model contains an extra input variable, z. It is called latent because its value is never observed. As the latent variable varies over a given set, the output prediction varies over the set of plausible predictions compatible with the input x. Latent-variable models can be trained with contrastive methods; a good example is the generative adversarial network (GAN). The critic (discriminator) computes an energy indicating whether the input y looks reasonable, and the generator network is trained to produce contrastive samples to which the critic is trained to assign high energy.
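The snippet below sketches how a latent variable turns one observation into a set of predictions. The `decoder` is hypothetical: each value of z in [0, 2π) picks one plausible y on a unit circle around x, a stand-in for the S-shaped ribbon. The energy of (x, y) is taken as the distance from y to the closest prediction, minimized over a grid of z values, which is one common way to define the energy of a latent-variable model.

```python
import math

def decoder(x, z):
    """Hypothetical predictor: each latent z in [0, 2*pi) picks one plausible y on a unit circle around x."""
    return (x[0] + math.cos(z), x[1] + math.sin(z))

def energy(x, y, num_z=360):
    """Latent-variable energy: distance from y to the closest prediction over a grid of z values."""
    best = float("inf")
    for k in range(num_z):
        z = 2 * math.pi * k / num_z
        px, py = decoder(x, z)
        best = min(best, math.hypot(y[0] - px, y[1] - py))
    return best

x = (0.0, 0.0)
print(energy(x, (1.0, 0.0)))  # ~0: y lies on the set of plausible predictions
print(energy(x, (3.0, 0.0)))  # 2.0: y is far from every plausible prediction
```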

But there is a major issue: contrastive methods are very inefficient to train. In high-dimensional spaces such as images, there are many ways one image can differ from another, and finding a set of contrastive images that covers all the ways they can differ from a given image is nearly impossible.

Non-contrastive Energy-based SSL

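This is possibly the hottest topic in SSL for vision at the moment, and it seems very promising [1]. Non-contrastive methods for joint embedding architectures use various tricks, such as computing virtual target embeddings for groups of similar images (as in SwAV) or making the two joint embedding architectures slightly different through the architecture or the parameter vector (as in BYOL). Non-contrastive (regularization) methods can also be applied to latent-variable predictive models; examples include the Variational Auto-Encoder (VAE) and sparse modeling, which limit the volume of space that can take low energy. SSL for vision is advancing with models like SEER, which trains a large convolutional network (ConvNet) with SwAV.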
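One way to make the two branches differ through the parameter vector is a momentum (exponential moving average) target branch, as used by BYOL. The sketch below shows only that update rule; the online parameters are held fixed for illustration, whereas in real training they change by gradient descent at every step.

```python
def ema_update(target_params, online_params, tau=0.99):
    """Move the target branch slowly toward the online branch: target <- tau*target + (1 - tau)*online."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target_params, online_params)]

online = [1.0, -2.0]  # online-branch parameters (fixed here for illustration)
target = [0.0, 0.0]   # target-branch parameters, never updated by gradients
for _ in range(3):
    target = ema_update(target, online, tau=0.5)
print(target)  # [0.875, -1.75]
```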

2. Self-Supervised Learning (SSL) 🎬

Labeled data is scarce, and labeled data in a specialized area is even scarcer. With enough labels, a task can be solved very well with supervised learning; however, good performance usually requires a massive amount of labels, which are expensive and laborious to produce. This makes labeled data the bottleneck in building machine learning models. In practice, it is impossible to label everything in the world: producing clean labels is expensive, while unlabeled data is being generated all the time. Moreover, some tasks simply do not have enough labeled data. Exploiting large amounts of unlabeled data with unsupervised learning is not easy and usually underperforms supervised learning. Furthermore, SSL can ease the difficulty of designing effective network architectures for specific tasks.

Humans have common sense, which refers to our generalized knowledge about the world. For instance, if you are sitting while reading this article, you do not expect to float away from the ground.

Common sense is the essence of biological intelligence in both humans and animals. It helps us learn new skills without requiring massive amounts of teaching for every single task. Simply put, humans rely on previously acquired background knowledge of how the world works. Thanks to this, once we have seen a picture of a cat, we are able to recognize cats, even ones we have never seen before.

Self-supervised learning (SSL) is one of the most promising ways to build such background knowledge and approximate common sense in intelligent systems [1]. It enables models to learn the important patterns needed to recognize and understand the data. In some settings, systems pretrained with SSL perform considerably better than those trained in a purely supervised manner. Rather than relying on human annotations, SSL learns latent representations from large-scale data by solving designed pretext tasks.

So basically, SSL gives us labels for free from unlabeled data and lets us train on an unlabeled dataset in a supervised manner. This is achieved by framing a supervised learning task in a special form: predicting only a subset of the information using the rest. In this way, both the inputs and the labels are obtained from the data itself.
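Rotation prediction (as in RotNet) is a classic example of getting labels for free. In this toy sketch, each rotated copy of an image is paired with the number of 90-degree clockwise turns that produced it, and that number becomes the classification target; images are plain nested lists for simplicity.

```python
def rotate90(image):
    """Rotate a 2D list (rows of pixels) by 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]

def rotation_pretext_examples(image):
    """Labels for free: each rotated copy is paired with how many quarter turns produced it."""
    examples = []
    current = image
    for k in range(4):
        examples.append((current, k))
        current = rotate90(current)
    return examples

image = [[1, 2],
         [3, 4]]
for x, label in rotation_pretext_examples(image):
    print(label, x)
```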

Summary of SSL, by LeCun [2]

SSL obtains supervisory signals by leveraging the underlying structure of the data. The general technique is to predict any unobserved or hidden part of the input from any observed or unhidden part of the input. For example, in NLP we can hide part of a sentence and predict the hidden words from the remaining words. Similarly, we can predict past or future frames in a video (hidden data) from current ones (observed data).

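The term unsupervised learning is somewhat ill-defined and misleading, since it suggests that no supervision is used at all. Self-supervised learning, by contrast, uses far more feedback signals than standard supervised and reinforcement learning methods do [1]: every hidden part of every input provides a training target, instead of a single label per sample or an occasional scalar reward.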

Self-supervised learning has made it possible to train NLP models such as BERT, RoBERTa, and XLM-R. These models are pretrained on large unlabeled text datasets and then fine-tuned for downstream tasks. In the pretraining phase, some of the words in the input are masked or replaced, and the model learns to predict them. By doing so, the model learns to represent the meaning of the text.
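The masking step of BERT-style pretraining can be sketched as below. Following the recipe described in the BERT paper, 15% of tokens are selected, and of those, 80% become `[MASK]`, 10% become a random token, and 10% are left unchanged. The helper name and the toy `vocab` are our own, and real implementations operate on token IDs rather than strings.

```python
import random

MASK = "[MASK]"

def bert_style_mask(tokens, vocab, select_prob=0.15, rng=None):
    """Return (model inputs, labels); labels hold the original token only where a prediction is required."""
    rng = rng or random.Random()
    inputs, labels = [], []
    for tok in tokens:
        if rng.random() < select_prob:
            labels.append(tok)  # the model must predict the original token here
            r = rng.random()
            if r < 0.8:
                inputs.append(MASK)              # 80%: mask it
            elif r < 0.9:
                inputs.append(rng.choice(vocab))  # 10%: replace with a random token
            else:
                inputs.append(tok)               # 10%: keep it unchanged
        else:
            labels.append(None)  # no loss at unselected positions
            inputs.append(tok)
    return inputs, labels

tokens = "the cat sat on the mat".split()
inputs, labels = bert_style_mask(tokens, vocab=["dog", "ran", "under"], rng=random.Random(0))
print(inputs)
print(labels)
```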

However, representing uncertainty in the prediction is significantly more difficult for images than it is for words [3], because we do not know how to efficiently represent uncertainty when predicting missing frames in a video or missing patches in an image. New SSL models are being developed to overcome this issue.

SSL frameworks can be categorized into two classes: predictive and contrastive. Predictive SSL frameworks include auto-encoding, Siamese, and clustering models.

SSL frameworks. a-c are predictive and d is contrastive SSL [3].

Predictive Models

Predictive models optimize the similarity or correlation between the representations of two views of the same object, without taking negative samples into account in the training objective.

Auto-encoding

Auto-encoding SSL is based on auto-encoders. A standard auto-encoder learns a compressed latent embedding of its input (the encoder output) and is trained to reconstruct the original input from that latent representation (the decoder output). The dimensionality of the latent representation must be carefully designed, as it determines how useful the representation is. If the latent dimensionality is too large, an auto-encoder risks learning an identity function, i.e., simply copying the input to the output, and hence becomes useless. One technique to prevent this is the denoising auto-encoder: the input is partially corrupted, for example by randomly zeroing some input values, and the model is trained to recover the original, undistorted input. Copying the corrupted input no longer minimizes the loss.
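The snippet below illustrates why corrupting the input helps, using a hypothetical `identity` model and masking noise (random zeroing). Under the plain reconstruction objective, copying the input achieves zero loss; under the denoising objective, the same copy is penalized for every zeroed value.

```python
import random

def zero_mask(x, drop_prob, rng):
    """Corrupt the input by randomly zeroing some values (masking noise)."""
    return [0.0 if rng.random() < drop_prob else v for v in x]

def mse(a, b):
    """Mean squared reconstruction error."""
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)

def identity(x):
    """The trivial 'auto-encoder' that copies its input to its output."""
    return list(x)

clean = [1.0, 2.0, 3.0, 4.0]
noisy = zero_mask(clean, drop_prob=0.5, rng=random.Random(0))

# Plain auto-encoder objective: reconstruct x from x. Copying is optimal.
print(mse(identity(clean), clean))  # 0.0
# Denoising objective: reconstruct the clean x from the corrupted input.
# Copying now pays for every zeroed value, so the model must learn the data's structure.
print(mse(identity(noisy), clean))
```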

Continuous Bag-of-Words (CBoW) and Skip-gram are used to learn underlying word representations. CBoW is trained to predict a single center word from its context words; conversely, Skip-gram predicts the left and right context words from the center word. CBoW is better at capturing syntactic relationships, while Skip-gram is better at capturing semantic relationships.

Two architectures of Word2Vec [3]
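A minimal sketch of how the two Word2Vec architectures turn raw text into training pairs, assuming a symmetric context `window` (the function names are ours):

```python
def cbow_pairs(tokens, window=2):
    """CBoW: (context words, center word) pairs; the model predicts the center from the context."""
    pairs = []
    for i, center in enumerate(tokens):
        context = tokens[max(0, i - window):i] + tokens[i + 1:i + 1 + window]
        pairs.append((context, center))
    return pairs

def skipgram_pairs(tokens, window=2):
    """Skip-gram: (center word, one context word) pairs; the model predicts each context word."""
    pairs = []
    for i, center in enumerate(tokens):
        for j in range(max(0, i - window), min(len(tokens), i + window + 1)):
            if j != i:
                pairs.append((center, tokens[j]))
    return pairs

tokens = ["the", "cat", "sat", "on", "mat"]
print(cbow_pairs(tokens, window=1)[1])       # (['the', 'sat'], 'cat')
print(skipgram_pairs(tokens, window=1)[:3])  # [('the', 'cat'), ('cat', 'the'), ('cat', 'sat')]
```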

Auto-regressive models can also be used to learn representations by predicting the future from the past context. Auto-regressive Predictive Coding (APC) encodes the input sequence using only unidirectional (past) information, and an additional context network aggregates the resulting representations up to the current time step. This context network is usually a recurrent neural network (RNN) that models the temporal information. The output context vector is used to predict the next (future) audio representation.
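A toy sketch of an APC-style objective: the model only sees frames up to time t and must predict the frame `shift` steps ahead, scored with an L1 loss. The `copy_last` predictor is a hypothetical baseline standing in for the RNN context network, and the frames are made-up two-dimensional feature vectors.

```python
def l1(a, b):
    """L1 distance between a predicted frame and a target frame."""
    return sum(abs(p - q) for p, q in zip(a, b))

def apc_loss(frames, predict, shift=1):
    """Mean L1 loss of a causal predictor that only sees frames[:t+1] when predicting frames[t+shift]."""
    steps = range(len(frames) - shift)
    total = sum(l1(predict(frames[: t + 1]), frames[t + shift]) for t in steps)
    return total / len(steps)

def copy_last(history):
    """Hypothetical baseline 'context network': predict that the last seen frame repeats."""
    return history[-1]

frames = [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0], [3.0, 1.0]]
print(apc_loss(frames, copy_last, shift=1))  # 1.0
print(apc_loss(frames, copy_last, shift=2))  # 2.0: predicting further ahead is harder
```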

Masked Predictive Coding (MPC) directly trains a bidirectional architecture by masking parts of the input signal. Transformer encoders and bidirectional RNNs have been used as context networks. Further, Non-autoregressive Predictive Coding (NPC) also applies a mask to its model input, but it learns representations based on local dependencies within the input sequence rather than global ones.

Auto-regressive Predictive Coding (APC) and Masked Predictive Coding (MPC) [3]
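The three objectives differ mainly in which frames the model may look at when predicting frame t. This simplified sketch is our own (real MPC masks several frames at once, and NPC uses masked convolutions); it lists the visible positions for each case, with `radius` and `mask_radius` as illustrative parameters.

```python
def visible_context(t, length, mode, radius=2, mask_radius=0):
    """Positions a model may use when predicting frame t in a sequence of `length` frames."""
    if mode == "apc":  # unidirectional: only the past
        return list(range(0, t))
    if mode == "mpc":  # bidirectional: everything except the masked frame
        return [i for i in range(length) if i != t]
    if mode == "npc":  # local: a window around t, minus the masked region
        window = range(max(0, t - radius), min(length, t + radius + 1))
        return [i for i in window if abs(i - t) > mask_radius]
    raise ValueError(f"unknown mode: {mode}")

for mode in ("apc", "mpc", "npc"):
    print(mode, visible_context(3, 6, mode))
# apc [0, 1, 2]
# mpc [0, 1, 2, 4, 5]
# npc [1, 2, 4, 5]
```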

Siamese Models

Siamese models have a typical two-tower architecture in which the towers have the same or similar architecture; their parameters can be shared or different. Each tower processes one view of a data sample. The encoded representations of the two views should be close to each other in the high-dimensional latent space, so during training the representations from one tower can serve as the training target (pseudo-labels) for the other tower.

Predictive Models Using Siamese Models [3]
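Taking SimSiam (one of the models listed under Contrastive Models) as a concrete instance, each tower's projection z serves as the target for the other tower's prediction p, compared with negative cosine similarity. Plain Python has no gradients, so the stop-gradient on the target is only noted in a comment (in PyTorch it would be `z.detach()`).

```python
import math

def neg_cosine(p, z):
    """D(p, z) = -cos(p, z). In SimSiam, z is treated as a constant target (stop-gradient)."""
    dot = sum(a * b for a, b in zip(p, z))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_z = math.sqrt(sum(b * b for b in z))
    return -dot / (norm_p * norm_z)

def simsiam_loss(p1, z1, p2, z2):
    """Symmetrized loss: tower 1's prediction targets tower 2's projection, and vice versa."""
    return 0.5 * neg_cosine(p1, z2) + 0.5 * neg_cosine(p2, z1)

print(simsiam_loss(p1=[1.0, 0.0], z1=[0.0, 2.0], p2=[0.0, 3.0], z2=[2.0, 0.0]))  # -1.0: perfectly aligned
```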

Clustering

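Clustering-based methods assume that each category occupies a separate manifold in the representation space, so cluster assignments of the learned representations can serve as pseudo-labels. K-means, DeepCluster, Local Aggregation, and SwAV are examples of clustering approaches used in SSL.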
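A DeepCluster-style sketch: cluster the embeddings with k-means and use the cluster ids as pseudo-labels for the next training round. For determinism, it initializes the centroids with the first k points (an assumption; practical implementations use random or k-means++ initialization), and the 2D points are toy embeddings.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two embeddings."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def kmeans_pseudo_labels(points, k, iters=10):
    """Cluster embeddings with k-means; the cluster ids serve as pseudo-labels."""
    centroids = [list(p) for p in points[:k]]  # simple deterministic init: first k points
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest centroid for each embedding.
        labels = [min(range(k), key=lambda c: sq_dist(p, centroids[c])) for p in points]
        # Update step: move each centroid to the mean of its members.
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # keep the old centroid if a cluster becomes empty
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return labels

embeddings = [(0.0, 0.0), (10.0, 10.0), (0.0, 1.0), (10.0, 11.0), (1.0, 0.0)]
print(kmeans_pseudo_labels(embeddings, k=2))  # [0, 1, 0, 1, 0]
```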

Contrastive Models

The key idea is to pull the representations of two similar inputs (a positive pair) close together in the latent space and to push the representations of dissimilar inputs (negative pairs) apart. Siamese models that employ contrastive learning include SimCLR and MoCo. Some models, such as BYOL, SimSiam, and Barlow Twins, do not use negative samples at all.
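As an example of a model that avoids negative samples, the Barlow Twins objective drives the cross-correlation matrix between the embeddings of two views toward the identity matrix. This is a plain-Python sketch: `lam` weights the off-diagonal redundancy term (5e-3 here as an illustrative default), and a real implementation would add a small epsilon to the standard deviation to avoid division by zero.

```python
import statistics

def standardize(batch):
    """Normalize each embedding dimension to zero mean and unit variance along the batch."""
    columns = list(zip(*batch))
    means = [statistics.fmean(c) for c in columns]
    stds = [statistics.pstdev(c) for c in columns]
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in batch]

def barlow_twins_loss(z1, z2, lam=5e-3):
    """Push the cross-correlation matrix of the two views' embeddings toward the identity."""
    a, b = standardize(z1), standardize(z2)
    n, d = len(a), len(a[0])
    loss = 0.0
    for i in range(d):
        for j in range(d):
            c_ij = sum(a[k][i] * b[k][j] for k in range(n)) / n
            if i == j:
                loss += (1.0 - c_ij) ** 2  # invariance term: matching dimensions should agree
            else:
                loss += lam * c_ij ** 2    # redundancy-reduction term: decorrelate dimensions
    return loss

views = [[1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]]
flipped = [[-x, y] for x, y in views]  # second view with the first dimension negated
print(barlow_twins_loss(views, views))    # 0.0: identical, decorrelated embeddings
print(barlow_twins_loss(views, flipped))  # 4.0: first dimension is anti-correlated
```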

3. SSL in the Wild 🦁

Image-based SSL

There are many ideas on this topic. A common workflow is to train a model on one or more pretext tasks with unlabeled images and then use one intermediate feature layer of this model to feed a multinomial logistic regression classifier trained on labeled data. The final classification accuracy quantifies how good the learned representation is.
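The evaluation protocol can be sketched as a linear probe on frozen features. For brevity, this is a binary logistic regression trained with full-batch gradient descent (the multinomial version swaps the sigmoid for a softmax), and `frozen_encoder` is a hypothetical stand-in for an intermediate layer of a pretrained SSL model.

```python
import math

def frozen_encoder(x):
    """Hypothetical stand-in for an intermediate layer of a pretrained SSL model (kept frozen)."""
    return [x[0] + x[1], x[0] - x[1]]

def train_linear_probe(features, labels, lr=0.5, epochs=200):
    """Fit only a linear classifier (w, b) on top of the frozen features."""
    w, b, n = [0.0] * len(features[0]), 0.0, len(features)
    for _ in range(epochs):
        gw, gb = [0.0] * len(w), 0.0
        for f, y in zip(features, labels):
            p = 1.0 / (1.0 + math.exp(-(sum(wi * fi for wi, fi in zip(w, f)) + b)))
            for i, fi in enumerate(f):
                gw[i] += (p - y) * fi / n  # gradient of the mean log loss w.r.t. w
            gb += (p - y) / n
        w = [wi - lr * g for wi, g in zip(w, gw)]
        b -= lr * gb
    return w, b

def accuracy(w, b, features, labels):
    preds = [1 if sum(wi * fi for wi, fi in zip(w, f)) + b > 0 else 0 for f in features]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

inputs = [(2.0, 1.0), (3.0, 2.0), (-2.0, -1.0), (-3.0, -1.0)]
labels = [1, 1, 0, 0]
features = [frozen_encoder(x) for x in inputs]  # the encoder itself is never updated
w, b = train_linear_probe(features, labels)
print(accuracy(w, b, features, labels))
```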

  • Distortion: A small distortion of an image does not modify its original semantic meaning or geometric form. Slightly distorted images are considered the same as the original, so the learned features are expected to be invariant to distortion.
  • Patches: Multiple patches are extracted from one image, and the model is asked to predict the relationship between these patches. Relative position is an important cue for understanding the spatial context of objects.
  • Colorization: Coloring a grayscale input image can be used as a powerful SSL task. More precisely, the model is trained to map the grayscale image to a distribution over quantized color values.
  • Generative Modeling: The pretext task is to reconstruct the original input while learning meaningful latent representation. Denoising Autoencoder learns to recover an image from a version that is partially corrupted or has random noise. Context Encoder is trained to fill in a missing piece in the image.
  • Bidirectional GANs: They introduce an additional encoder that learns the mapping from the input to the latent variable z. The discriminator makes its prediction in the joint space of the input data and the latent representation, distinguishing real pairs (x, E(x)) from generated pairs (G(z), z).
  • Contrastive Learning: Contrastive Predictive Coding (CPC) is an approach for unsupervised learning from high-dimensional data that translates a generative modeling problem into a classification problem. Its contrastive loss, InfoNCE, is inspired by Noise Contrastive Estimation (NCE). It uses a cross-entropy loss to measure how well the model can classify the future representation amongst a set of unrelated negative samples.
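The InfoNCE loss from the contrastive-learning item can be written as a cross-entropy over one positive and several negatives. This sketch assumes a dot-product similarity scaled by a `temperature` (CPC itself scores pairs with a learned bilinear function), and the log-sum-exp trick keeps it numerically stable. With equal scores for all N candidates, the loss equals log N, which is a handy sanity check.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def info_nce(query, positive, negatives, temperature=0.1):
    """Cross-entropy of identifying the positive among {positive} + negatives."""
    logits = [dot(query, positive) / temperature] + [dot(query, n) / temperature for n in negatives]
    m = max(logits)  # subtract the max before exponentiating (log-sum-exp trick)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[0]  # = -log softmax(logits)[0]

query = [1.0, 0.0]
print(info_nce(query, positive=[1.0, 0.0], negatives=[[-1.0, 0.0], [0.0, 1.0]]))  # small: positive is most similar
print(info_nce(query, positive=[-1.0, 0.0], negatives=[[1.0, 0.0], [0.0, 1.0]]))  # large: a negative is most similar
```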

Audio-based SSL

Most of the aforementioned SSL methods have been transferred to the audio domain. Speech2Vec is built on an RNN encoder-decoder that processes the mel-spectrogram of the audio, and Audio2Vec is built from stacks of CNN blocks that process mel-frequency cepstral coefficients (MFCC). HuBERT, Wav2Vec, and CPC models use the raw waveform as input. Contrastive losses such as InfoNCE are commonly employed. For the encoder, components such as 1D CNNs, Transformers, BERT-style encoders, LSTMs, and their combinations are used.

Wav2Vec Architectures [3]
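To see how a raw waveform becomes a sequence of frames, here is the arithmetic for a wav2vec 2.0-style convolutional feature encoder. The kernel widths and strides are the ones reported in the wav2vec 2.0 paper for 16 kHz audio; the helper functions are our own.

```python
kernels = [10, 3, 3, 3, 3, 2, 2]  # wav2vec 2.0 feature-encoder kernel widths
strides = [5, 2, 2, 2, 2, 2, 2]   # ... and strides of its seven conv layers
sample_rate = 16_000              # Hz

def conv_output_length(length, kernel, stride):
    """Output length of a 1D convolution without padding."""
    return (length - kernel) // stride + 1

def receptive_field(kernels, strides):
    """Input samples seen by one output frame, and the hop between consecutive frames."""
    field, jump = 1, 1
    for k, s in zip(kernels, strides):
        field += (k - 1) * jump  # each layer widens the view by (k - 1) input-level steps
        jump *= s                # total stride accumulated so far
    return field, jump

field, hop = receptive_field(kernels, strides)
length = sample_rate  # one second of raw audio
for k, s in zip(kernels, strides):
    length = conv_output_length(length, k, s)

print(field, field * 1000 / sample_rate)  # 400 25.0  (samples, ms)
print(hop, hop * 1000 / sample_rate)      # 320 20.0  (samples, ms)
print(length)                             # 49 frames for 1 s of audio
```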

Conclusion

In this post, we went through the core idea of SSL and its similarity to common sense, the background knowledge humans have about the world. This idea is implemented via SSL models. In the SSL paradigm, a model learns to represent and understand data by generating its own labels/targets (pseudo-labels) from the input data itself, unlike supervised learning, which relies on externally provided labels. There are two main approaches in SSL: predictive models and contrastive models. Predictive models are trained to predict one part of the input data from another part of the same data; the idea is to create a task that forces the model to learn meaningful representations of the input. Contrastive models learn representations by contrasting similar and dissimilar examples, so that similar data points move closer together in the representation space while dissimilar data points are pushed farther apart.

For further research papers about self-supervised learning, this GitHub repository is a great place to start: https://github.com/jason718/awesome-self-supervised-learning

References

[1] https://ai.meta.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/

[2] https://www.youtube.com/watch?v=7I0Qt7GALVk

[3] Liu, S., Mallol-Ragolta, A., Parada-Cabaleiro, E., Qian, K., Jing, X., Kathan, A., … & Schuller, B. W. (2022). Audio self-supervised learning: A survey. Patterns, 3(12).
