Multimodality: attention is all you need

Aleph Alpha
Aleph Alpha Blog
Feb 14, 2022

AI is the buzzword of the current age and, unlike many buzzwords, not entirely without merit. Recent AI advances have rapidly pushed the limits of what is possible, but our systems remain narrowly specialized on one task or, at most, a few closely related ones. The elusive goal of “Artificial General Intelligence” (AGI) promises systems that learn in a far more general way and can perform a wide range of tasks in many domains without custom-fitting.

A key to enabling the dream of AGI is multimodality: training an AI system not just on a single input type (“modality”) such as vision, speech or text, but on many simultaneously, so that it can transfer what it has learned in one domain to others. Multimodality promises groundbreaking advances in general intelligence applications. Let’s explore!

When training our AI models, what we’re trying to build is a model of reality that captures the properties necessary to perform whatever task we’re trying to do. For example, if we want to classify dogs, we need to learn a model of “dog-ness” that can recognize the relevant properties of images that contain dogs. We can imagine some underlying Reality that we want to model in order to answer questions (such as “Is this a dog?”, “What will happen next?” or “What should I do next?”). We do not get direct access to the true underlying state of Reality (and probably couldn’t do much with it even if we did, since it would be far too large to deal with). Instead, we perceive Reality through a number of Senses. A Sense is some kind of preprocessed slice of the true state of Reality that we have access to.

Reality (Complex Quantum State) -> Visual Preprocessing Sense (Taking in certain subsets of photonic information and outputting an image in RGB space) -> Input into our models

Each of these Senses represents a modality. The output of a Sense does not contain the full state of Reality: some data is discarded in the encoding to make the input manageable. We now have to figure out how to get from our Input to a solution to the problem we’re trying to solve. This is generally accomplished by training a neural network (NN) for the task. Such an NN will then (generally) learn new, higher-level concepts that it can extract from its input data and operate on.

In this visualization of an object classification NN (from https://distill.pub/2020/circuits/early-vision/), we can see how several neurons are combined into a circle-recognizing circuit. Circles are useful for classifying the kinds of objects we care about, so it makes sense for the network to learn about them.

The key to the success of our NN is to extract the right features from its Input and then operate on these (this is sometimes called Feature or Representation Learning). This applies just as much to image recognition as it does to playing chess or any other domain. Chess players also naturally recognize different features of board positions and use these to plan their strategies. Text can be seen as a special kind of Sense, a “second-order Sense”. Text is a downstream product of human intelligence, which itself is downstream of the primary senses encoding reality. This explains why, despite having no “direct” contact with Reality, models such as GPT3 can pick up significant amounts of real world knowledge.

Text as a second order sense.

Learning such useful features, and strategies utilizing them, is the fundamental problem our training is trying to solve. There are (at least) three reasons why our model might fail to learn a useful property:

  • We do not have enough data to learn the property from (or we have the wrong kind of data)
  • Our model lacks computational or algorithmic power to learn the property
  • The preprocessing of the Senses makes learning the property hard or impossible

Multimodality clearly adds complexity to the model, but it provides many benefits:

  • We have access to larger pools of data since we are not limited just to one type of data
  • We can pair data from different modalities to multiply the effective number of samples we can show our model, both by combining “matching” data (positive sampling) and intentionally showing mismatched data to show our model what is “wrong” (negative sampling)
  • Properties that cannot be learned from one Sense can potentially be learned through another
  • Performance gains in one modality may translate to others as well
  • There are potentially more ways to interrogate a model and understand “what it is thinking”
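To make the pairing argument concrete: from n matched image/caption pairs we can build n positive and n(n-1) negative combinations. The sketch below assumes each image’s caption sits at the same index; the file names and captions are placeholders, and real pipelines usually sample negatives rather than enumerating all of them.

```python
from itertools import product


def make_pairs(images, captions):
    """Build (image, caption, label) triples from n matched pairs.

    Assumes images[i] and captions[i] describe the same thing.
    Label 1 marks a matching (positive) pair, 0 a mismatched (negative) one.
    """
    if len(images) != len(captions):
        raise ValueError("expected one caption per image")
    pairs = []
    for i, j in product(range(len(images)), repeat=2):
        pairs.append((images[i], captions[j], 1 if i == j else 0))
    return pairs


images = ["dog.jpg", "cat.jpg", "car.jpg"]          # placeholder inputs
captions = ["a dog barking", "a sleeping cat", "a red car"]
pairs = make_pairs(images, captions)
positives = sum(label for _, _, label in pairs)
print(len(pairs), positives, len(pairs) - positives)  # 9 3 6
```

Three matched samples already yield nine training signals, and the count grows quadratically with the number of pairs.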

A Blueprint

But how can we make our multimodal dreams come true? The literature is full of countless possible architectures, but, in general, architectures tend to follow the rough blueprint of:

Reality -> Encoders/Senses -> Common Latent Space -> Model

We want a number of Encoders/Senses, one for each input modality, that map our inputs into some “common language” (here called a “latent space”), which the model can then operate on (for example, by predicting future states). Technically, we could have multiple variants per modality, such as one encoder for image luminosity and another for color, but this is not typical.
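A minimal sketch of the blueprint, assuming toy encoders: each modality gets its own encoder with its own input size, but all encoders emit vectors of one shared size, so the model sees a single sequence of comparable latents. The random linear maps, the feature sizes and `LATENT_DIM` are hypothetical stand-ins for learned networks such as a ResNet or a token embedding table.

```python
import random

LATENT_DIM = 4  # hypothetical size of the shared latent space


def make_linear_encoder(in_dim, out_dim, seed):
    """Return a toy encoder: a fixed random linear map from in_dim to out_dim."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)]

    def encode(x):
        if len(x) != in_dim:
            raise ValueError(f"expected {in_dim} features, got {len(x)}")
        return [sum(w * v for w, v in zip(row, x)) for row in weights]

    return encode


# One encoder per modality: input sizes differ, the output size is shared.
encoders = {
    "image": make_linear_encoder(in_dim=6, out_dim=LATENT_DIM, seed=0),
    "text": make_linear_encoder(in_dim=3, out_dim=LATENT_DIM, seed=1),
}


def to_common_sequence(inputs):
    """Encode (modality, features) items into one sequence of shared latents."""
    return [encoders[modality](features) for modality, features in inputs]


sequence = to_common_sequence([
    ("text", [0.1, 0.5, -0.2]),
    ("image", [0.3, 0.0, 0.9, -0.4, 0.2, 0.7]),
])
print(len(sequence), [len(v) for v in sequence])  # 2 [4, 4]
```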

You might be wondering why we would want a single unified “language” for different input modalities. Does it make sense to encode e.g. images and text into the same space?

Reality -> Encoders/Senses -> Separate Latent Spaces -> Model, an alternative way of doing things

Encoding everything into the same space is necessary for several of the benefits we’re looking for in our models, first and foremost the ability to compare and translate between inputs of different modalities. When you hear a dog bark, you can imagine a picture of a dog, after all. Combining the different modalities is half of the reason we’re even bothering with multimodality! And, arguably, even if one were to use separate spaces for different inputs, the model would eventually want to “mix” the different inputs internally, leading to an implicit version of a shared latent space anyway.
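Once inputs share a space, comparing across modalities reduces to a vector similarity. The sketch below, assuming cosine similarity as the metric, retrieves the image embedding closest to a sound embedding; the vectors are made-up numbers, not outputs of any trained model.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def nearest(query, candidates):
    """Return the key of the candidate embedding most similar to the query."""
    return max(candidates, key=lambda name: cosine(query, candidates[name]))


# Hypothetical embeddings that trained encoders might place in a shared space.
sound_of_bark = [0.9, 0.1, 0.0]
image_embeddings = {
    "photo_of_dog": [0.8, 0.2, 0.1],
    "photo_of_car": [0.0, 0.1, 0.9],
    "photo_of_cat": [0.3, 0.9, 0.1],
}
print(nearest(sound_of_bark, image_embeddings))  # photo_of_dog
```

With separate, unaligned spaces this comparison would be meaningless, because the dimensions of the two vectors would not correspond to one another.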

Implementing the Blueprint

While the rough outline of our plan seems pretty clear, what is far less clear is exactly which of the myriad proposed architectures and algorithmic details to use for our blueprint. I am willing to bet on one factor, though: the backbone of our core model will most likely be a transformer. Transformers have driven recent revolutions, first in NLP and more recently in many adjacent fields as well. BERT, GPT3, DALL-E and other breakthrough results are all built on transformers.

But there are many other factors that are far less clear at the moment:

  • How should the various modalities be encoded into latents? Discrete values/pixels (iGPT)? Tokens (DALL-E)? Continuous vectors (Aleph Alpha experiments)?
  • What is the correct metric/loss for encoding? For example, MSE is often used on images, but may be far more sensitive to high frequency features in images than humans are.
  • There are many different types of training regimes, most notably autoregressive generation (predicting the next token given the previous tokens), masked generation (generating a “masked out” token in a sequence) and contrastive learning (learning to encode matching samples close together and mismatched samples “as far away from one another as possible”). What are the benefits and downsides of the different schemes?
  • How can “matched” datasets be collected at scale? Large text datasets have become commonplace, but matched text/images/sound and such far less so.
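The first two training regimes differ mainly in how a token sequence is turned into inputs and targets. Below is a minimal sketch on a toy word sequence; the `[MASK]` symbol and the tokens are illustrative assumptions. A contrastive objective instead scores pairs of samples and does not fit this one-sequence form.

```python
MASK = "[MASK]"  # illustrative placeholder symbol


def autoregressive_examples(tokens):
    """(context, next token) pairs: predict each token from everything before it."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


def masked_example(tokens, positions):
    """Hide the tokens at the given positions; the model must recover them."""
    corrupted = [MASK if i in positions else t for i, t in enumerate(tokens)]
    targets = {i: tokens[i] for i in positions}
    return corrupted, targets


tokens = ["a", "dog", "on", "grass"]
for context, target in autoregressive_examples(tokens):
    print(context, "->", target)
# ['a'] -> dog
# ['a', 'dog'] -> on
# ['a', 'dog', 'on'] -> grass
print(masked_example(tokens, positions={1}))
# (['a', '[MASK]', 'on', 'grass'], {1: 'dog'})
```

The autoregressive form only ever looks left, which is what makes it usable for generation; the masked form sees context on both sides of the gap.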

The Frontier: DALL-E and CLIP

OpenAI has recently published groundbreaking multimodal work that illustrates the concepts above nicely:

DALL-E

OpenAI is a pioneer in the field of large, unsupervised world model training. One of their latest successes is DALL-E, which is a wonderful example of the kinds of multimodal transformers we have been discussing. DALL-E is capable of producing extremely high-quality and diverse images from a textual description. Unlike previous attempts using GANs and similar methods, DALL-E is a pure autoregressive transformer, only predicting tokens! The model is composed of two parts:

  • A Variational Autoencoder (VAE) with discrete latents trained on the images from the dataset. The VAE learns to encode an image into a number of discrete tokens that maximize its ability to reconstruct that image, therefore learning a kind of “vocabulary of image-pieces.” The VAE is trained ahead of time.
  • A simple token-predicting transformer, which predicts the relevant image tokens after seeing the textual description tokens. And…that’s it! DALL-E simply predicts the corresponding VAE tokens from the input textual description, and the VAE then decodes these tokens back into an image. This simple method produces truly stunning results; check out the gallery!
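The recipe above boils down to turning an image into a handful of discrete IDs and appending them after the text tokens, so that one transformer can treat everything as next-token prediction. The sketch below assumes a tiny hand-picked codebook and a VQ-VAE-style nearest-codebook lookup; DALL-E’s discrete VAE learns its codebook and is trained differently, so treat this only as an illustration of the resulting token stream. `CODEBOOK` and `TEXT_VOCAB_SIZE` are hypothetical.

```python
import math

# Hypothetical codebook: each entry is one "image piece" the encoder can emit.
# In DALL-E the codebook is learned; here it is hand-picked for illustration.
CODEBOOK = [
    [0.0, 0.0, 0.0],  # token 0: dark patch
    [1.0, 1.0, 1.0],  # token 1: bright patch
    [0.2, 0.6, 0.2],  # token 2: greenish patch
    [0.2, 0.3, 0.9],  # token 3: bluish patch
]
TEXT_VOCAB_SIZE = 1000  # hypothetical; image token IDs start after the text IDs


def quantize(patch):
    """Map a patch feature vector to the index of its nearest codebook entry."""
    return min(range(len(CODEBOOK)), key=lambda k: math.dist(patch, CODEBOOK[k]))


def build_sequence(text_ids, patches):
    """Text tokens first, then image tokens shifted into their own ID range."""
    image_ids = [TEXT_VOCAB_SIZE + quantize(p) for p in patches]
    return text_ids + image_ids


patches = [[0.1, 0.2, 0.8], [0.25, 0.55, 0.1], [0.9, 0.95, 1.0], [0.05, 0.0, 0.1]]
seq = build_sequence([17, 42, 7], patches)
print(seq)  # [17, 42, 7, 1003, 1002, 1001, 1000]
```

During training the transformer learns to predict each ID in `seq` from the ones before it; at generation time it is given only the text IDs and samples the image IDs, which the decoder turns back into pixels.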

CLIP

Paired with DALL-E, OpenAI has also recently unveiled their CLIP system. CLIP is built upon a deceptively simple principle: encode a batch of n texts and n images, combine each text with each image (creating n² combinations), and test how well these pairings “fit”. CLIP learns to match images with their true text labels, while in turn minimizing the “closeness” to the other, non-fitting labels. This technique is called contrastive learning, and the benefits shown by CLIP are immense.
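The contrastive objective can be written down in a few lines. Given an n-by-n matrix of image/text similarity scores with the true pairs on the diagonal, each image must pick out its own text among n candidates, and each text its own image. This is a minimal sketch of that symmetric cross-entropy, assuming the similarities are already computed and using a fixed temperature (CLIP itself learns this scale).

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of a list of scores."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def clip_style_loss(sim, temperature=1.0):
    """Symmetric contrastive loss over an n x n similarity matrix.

    sim[i][j] scores image i against text j; the true pairs sit on the diagonal.
    """
    n = len(sim)
    logits = [[s / temperature for s in row] for row in sim]
    # Image -> text: each row is a classification over the n texts.
    image_to_text = -sum(log_softmax(logits[i])[i] for i in range(n)) / n
    # Text -> image: each column is a classification over the n images.
    columns = [[logits[i][j] for i in range(n)] for j in range(n)]
    text_to_image = -sum(log_softmax(columns[j])[j] for j in range(n)) / n
    return (image_to_text + text_to_image) / 2


well_matched = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]
confused = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
print(round(clip_style_loss(well_matched), 4))  # 0.0134
print(round(clip_style_loss(confused), 4))  # 1.0986 (= ln 3)
```

Every off-diagonal entry acts as a negative sample, which is where the n² combinations pay off: a model that cannot tell the pairs apart is stuck at ln n.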

Functioning of CLIP

CLIP makes for a pretty good caption recognizer (it can’t technically generate captions, only recognize which given captions fit which given images), but what makes it more interesting is its robustness. While there are other networks that do better than CLIP on the data they are specifically trained on, CLIP does amazingly well on data it hasn’t seen before.

CLIP off distribution results (grey bar is ResNet, blue is CLIP)

Using contrastive learning, CLIP can use its data far more efficiently, thanks to the negative sampling of combining images with “wrong” labels during training. This is inherently possible thanks to CLIP’s multimodality. This powerful technique is an important ingredient in achieving the hoped-for benefits of multimodality.

Research at Aleph Alpha

We have been working on multimodality, contrastive learning and related methods for about two years now. In a March 2020 experiment (with multimodal transformers similar to DALL-E), we used a ResNet-based encoder to encode images into a latent vector that we fed into the network alongside the text tokens.

A model like this is capable of capturing the contents and context of images beyond classical object detection. We further improved model transparency by visualizing the activations in the ResNet and the image patches with the biggest semantic impact on a per-token basis. In the example below, we see that for the token “television” the biggest impact is indeed on the TV and the cabinet above it.

In May 2020 we combined multimodal encoders with a method pioneered for multiple “views” of an object: contrastive multiview coding. Using contrastive coding, multiple “views” (encodings) of the same object are mapped into a latent space so that they cluster closely together and lie maximally far away from any other objects. At Aleph Alpha, we applied this same technique not to multiple types of images, but to image and text pairs.

A powerful advantage of this approach is the creation of a meaningful multimodal neighborhood. In the dimensionally reduced example above, the “airplanes in the sky” form a cluster separate from the “airplanes on the ground” cluster. These two clusters are themselves close to each other (as they both contain airplanes). Such neighborhoods can enable powerful search, sorting and analysis use cases that combine a complex understanding of images with other images or with texts, even when they contain more than one kind of content.

In a “truly” multimodal transformer from Nov 2020, we combined GPT-3 text tokens and iGPT pixel tokens in the same model. The model was capable of reading and writing images or text and combining the contents of these two modalities, turning simple textual descriptions into simple images (the results aren’t perfect, but it was a promising proof of concept!).
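One simple way to let a single transformer read and write both modalities is a shared vocabulary: pixel tokens get IDs after the text IDs, and decoding routes each ID back to its modality. The post does not describe how the two token types were combined, so this is only an illustrative sketch; the tiny vocabulary and palette are hypothetical (iGPT itself maps pixels to a palette of color clusters learned with k-means).

```python
TEXT_VOCAB = ["<pad>", "a", "red", "square", "<image>"]  # hypothetical text vocabulary
PALETTE = [(0, 0, 0), (255, 255, 255), (220, 30, 30)]   # hypothetical color clusters
OFFSET = len(TEXT_VOCAB)  # pixel tokens occupy the IDs after the text IDs


def pixel_token(rgb):
    """Token ID of the nearest palette color (squared Euclidean distance)."""
    best = min(range(len(PALETTE)),
               key=lambda k: sum((a - b) ** 2 for a, b in zip(rgb, PALETTE[k])))
    return OFFSET + best


def encode(words, pixels):
    """One shared token sequence: text tokens, then pixel tokens."""
    return [TEXT_VOCAB.index(w) for w in words] + [pixel_token(p) for p in pixels]


def decode(ids):
    """Route each ID back to its modality: a word or a palette color."""
    return [TEXT_VOCAB[i] if i < OFFSET else PALETTE[i - OFFSET] for i in ids]


ids = encode(["a", "red", "square", "<image>"], [(250, 10, 20), (200, 40, 35), (5, 5, 5)])
print(ids)  # [1, 2, 3, 4, 7, 7, 5]
print(decode(ids)[-3:])  # [(220, 30, 30), (220, 30, 30), (0, 0, 0)]
```

Because both modalities live in one ID range, the same next-token objective covers writing text after an image and writing pixels after a description.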

Connor’s Connection of (fore)Casts

As the saying goes, it’s hard to make predictions, especially about the future. Where many before me have failed, I will now strive to…probably fail as well by making some predictions of my own for what the near-term future of AI technology will hold:

  1. Large pretrained world models are the future of AI. I think the writing is on the wall with regard to pretraining large models and then fine-tuning them for specific applications. As with GPT3 and DALL-E, I expect more of the cutting-edge results to continue to be powered by this kind of architecture.
  2. Multimodality produces better world models (probably). It seems very likely to me that there are many useful properties we would want our models to learn that they can’t learn (easily, at least) exclusively from e.g. text (such as an understanding of physics, or the flow of time). Multimodal training seems like the obvious way to overcome this limitation.
  3. Any fancy dataset or architecture might not matter if scaling laws continue to hold. Even though multimodality will probably help with learning useful tasks, and for some tasks will be mandatory, there is a long history of simple methods dominating sophisticated/clever approaches when scaled up. If hardware continues to grow exponentially more powerful, and simple models such as transformers continue to scale, then… Attention May Really Be All We Need™.

So what would I suggest are the next steps for anyone interested in this field?

  1. Hardware is king. The techniques for building large SOTA models are widely known and available; the bottleneck remains the massive amount of hardware required.
  2. Large multimodal datasets are still not widely available. There is no doubt that the huge datasets that went into the training of models such as DALL-E and CLIP are a large part of what makes them successful. At the time of writing, datasets at this scale are not easily accessible.
  3. Turning pretrained models into valuable use cases (e.g. through fine-tuning or RL) is still an underexplored area. Powerful, generic, multimodal systems are cutting-edge developments, and their practical uses are only beginning to be explored. The GPT3 API and the many startups sprouting up around it are only the beginning!

This article was originally published on https://aleph-alpha.de.

We are an independent European company researching, developing and operationalizing a new foundational AI technology for the public and private sector