AlphaFold3 and its improvements in comparison to AlphaFold2

Falk Hoffmann
28 min read · May 10, 2024


Are you curious about what this text holds for you?

  • A short history of protein structure methods
  • Short overview of the performance of AlphaFold3 in comparison to other SOTA methods on various complex structure prediction tasks
  • A step-by-step guide through the AlphaFold3 model architecture and its improvements in comparison to AlphaFold2
  • Technical details about the implementation of AlphaFold3

History of protein structure prediction methods

Predicting protein structures is one of the most important challenges in biochemistry. Highly accurate protein structures are essential for drug discovery purposes. Here, you can find more about the background. Protein structure prediction began in the 1950s with the advent of computational methods and the growing understanding of protein structure. It was initially dominated by physics-based methods and theoretical models. The computational power available at the time limited those models, and they were not very successful in predicting the structure of most proteins. The next wave of protein structure models were homology models, which emerged in the 1970s. Those models rely on the observation that proteins with similar sequences also have similar structures. Multiple sequence alignments of the target sequence to template sequences with available structures made it possible, for the first time, to predict structures of sequences that had not been resolved before. However, the resolution of those models was still limited. The next jump in resolution came with the advent of ab initio methods in the 1980s, which applied physics-based methods and optimisation algorithms.
Combined with computational advances, this led to significant improvements in protein structure prediction. To benchmark all those new methods, the Critical Assessment of Techniques for Protein Structure Prediction (CASP) series was initiated in the early 1990s. In recent years, machine learning and deep learning techniques have been increasingly integrated into protein structure prediction methods, especially since long short-term memory networks were first applied in 2007. Google’s DeepMind introduced AlphaFold in the 13th CASP competition, which was held in 2018. AlphaFold uses a neural network that predicts distances between pairs of residues from the protein sequence and sequence homology and derives the structure from them. While it was immediately the best prediction method in CASP13, it was still not good enough to call the protein structure prediction task solved. A successful in silico prediction method reaches experimental accuracy. Such precision was reached with AlphaFold2, which was introduced in CASP14 with a reinvented model architecture that directly predicts the 3D coordinates of all non-hydrogen atoms of a given protein. It consists of an Evoformer module and a structure module, shown in the following figure.

The network architecture of AlphaFold2. Figure extracted from the original paper.

Soon after AlphaFold2 was published, models with similar architectures appeared to either improve its accuracy or make it faster. Some examples are RoseTTAFold from the lab of David Baker, OmegaFold, and ESMFold from Meta AI. In collaboration with the European Bioinformatics Institute (EBI), DeepMind also created the AlphaFold Protein Structure Database, which contains more than 200 million predicted structures of proteins covering most of the sequences in UniProt. Many biotech companies frequently use those predictions to design new drugs.

Quick view on the AlphaFold3 performance

On May 8th, 2024, Isomorphic Labs and DeepMind published AlphaFold3. While AlphaFold3 is more accurate than AlphaFold2 in predicting single protein structures, its main advantages are its more precise prediction of protein complexes and the extension of its applications from proteins to other molecules, covering nearly all molecule types in the Protein Data Bank (PDB). For example, AlphaFold3 outperforms classical docking tools like the state-of-the-art (SOTA) Vina and recent ML tools like RoseTTAFold All-Atom on protein-ligand interfaces from the PoseBusters benchmark, which contains 428 protein-ligand structures released to the PDB in 2021 or later. AlphaFold3 also achieves higher accuracy in the prediction of protein-nucleic acid complexes and RNA structures than the SOTA RoseTTAFold2NA and AIchemy_RNA, the best AI submission of CASP15, on the CASP15 examples and a PDB protein-nucleic acid dataset. On the CASP15 benchmark, only the best human-expert-aided submission, AIchemy_RNA2, was slightly better than AlphaFold3.

Furthermore, AlphaFold3 also accurately predicts the effect of covalent modifications like bonded ligands, glycosylation, modified protein residues and modified nucleic acid bases on proteins, RNA or DNA. However, no comparison to other tools is reported. The share of good predictions, measured by a pocket RMSD of less than 2 Å on high-quality experimental datasets, varies between 40% for modified RNA residues and nearly 80% for bonded ligands. The limited number of examples in the datasets leads to a relatively high statistical error on those values. AlphaFold3 also improves the prediction of protein-protein complexes, with a significant improvement in antibody-protein interfaces compared to AlphaFold-Multimer v2.3.

AlphaFold3 prediction of the 8AW3 — RNA modifying protein. Figure from Google’s blog.

We will probably see many additional benchmarks on AlphaFold3 in the coming months, but the reported improvements on different tasks are promising.

General AlphaFold3 architecture

The following figure shows the overall model of AlphaFold3.

The network architecture of AlphaFold 3. Figure extracted from the original paper.

The structure of the model is similar to the structure of AlphaFold2, but many steps have been improved for a more accurate prediction of the structures of proteins and their complexes. Similar to AlphaFold2, a template search and a genetic search are performed. Both, in addition to the output of a conformer search, are used as inputs to the Template module and MSA module. The MSA module is smaller than in AlphaFold2 (see below why). A Pairformer module replaces the Evoformer module of AlphaFold2. This module only processes the single and pair representations but not the MSA representations. The structure module in AlphaFold2 is replaced by a Diffusion module. As a generative method, diffusion gives a distribution of structures instead of a single structure with uncertainty, so every sample is a sharply defined structure, and residue-specific parametrisations are avoided. Importantly, no physics-based minimisation is needed, unlike in AlphaFold2, where AMBER was used to relax the side chain atom positions. To prevent the diffusion process from generating plausible-looking structure in unstructured regions, cross-distillation was used with training data predicted by AlphaFold-Multimer v2.3, which represents those regions as extended loops. Finally, a confidence module is introduced to estimate the atom-level and pairwise errors of the prediction.

Now, let’s go through all these modules step by step to understand what they are doing and how they contribute to improving the prediction of protein structures compared to their predecessors from AlphaFold2.

AlphaFold3 algorithm

Similar to the structure in AlphaFold2, the architecture in AlphaFold3 relies on transformers. If you need clarification on how transformers work, read about them before you continue reading this text, for example, my previous text about transformers for small molecule property predictions. Most of those concepts will show up in this text. The main algorithm is the following.

Algorithm of AlphaFold3. Extracted from the original paper.

The algorithm describes a conditional diffusion model. Let’s go through it step by step to understand what is happening.

Step 1: Input tokens and embeddings

First, the input molecules (protein, RNA, DNA, small molecules) must be converted into a mathematical form. Molecules are described by atoms and bonds connecting those atoms. However, proteins, RNA and DNA have a regular general structure (e.g. proteins consist of amino acids, and DNA and RNA consist of nucleotides), which makes an atom-level representation of those entities too detailed. Therefore, different tokens are used for different molecules. AlphaFold2 was a pure protein structure prediction tool which did not consider non-protein binding partners. There were 23 tokens in AlphaFold2: one token for each of the 20 standard amino acids, a token for an unknown amino acid, a gap token and a masked multiple sequence alignment (MSA) token. In AlphaFold3, RNA, DNA and general molecules are also considered. For DNA and RNA, tokens correspond to entire nucleotides. For all other general molecules, every single heavy atom is represented by its own token.
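
The tokenisation rule described above can be illustrated with a minimal sketch. The input format (a list of `(molecule_type, units)` pairs) and the function name `tokenize` are hypothetical and chosen only for illustration; they are not taken from AlphaFold3.

```python
def tokenize(chains):
    """Turn chains into tokens: one per residue/nucleotide, one per ligand heavy atom.

    `chains` is a hypothetical input format: a list of (molecule_type, units).
    Polymer units are residue names; ligand units are (atom_name, element) pairs.
    """
    tokens = []
    for chain_index, (molecule_type, units) in enumerate(chains):
        if molecule_type in ("protein", "rna", "dna"):
            # Polymers: the whole residue or nucleotide is a single token.
            tokens.extend((chain_index, molecule_type, name) for name in units)
        elif molecule_type == "ligand":
            # General molecules: every heavy (non-hydrogen) atom is its own token.
            tokens.extend(
                (chain_index, molecule_type, atom_name)
                for atom_name, element in units
                if element != "H"
            )
        else:
            raise ValueError(f"unknown molecule type: {molecule_type!r}")
    return tokens


example_chains = [
    ("protein", ["MET", "ALA", "GLY"]),
    ("dna", ["DA", "DT"]),
    ("ligand", [("C1", "C"), ("O1", "O"), ("H1", "H")]),
]
# 3 residues + 2 nucleotides + 2 heavy atoms give 7 tokens.
example_tokens = tokenize(example_chains)
```

The point of the sketch is only the asymmetry: polymers are coarse-grained to residues, while a ligand contributes as many tokens as it has heavy atoms.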

Those tokens have features which have to be embedded. The following features are used in AlphaFold3:

  • Token features are the residue number residue_index, the token number token_index (increases monotonically from the start token of the input), the chain index asym_id, the index entity_id of every distinct sequence, the index sym_id that distinguishes chains sharing a sequence, the residue type restype (20 (+1 unknown) amino acids, 4 (+1 unknown) RNA nucleotides, 4 (+1 unknown) DNA nucleotides and a gap residue) and the masks is_protein / rna / dna / ligand specifying the type of molecule.
  • Reference conformer features are the atom positions ref_pos after applying a random rotation and translation (3 values given in Å), the mask ref_mask giving the used atoms in the conformer, the atomic number ref_element of the element of the atom, the charge ref_charge of the atom, the unique atom name ref_atom_name_chars and a unique ID ref_space_uid of every combination of chain ID and residue index.
  • The MSA features are the encoding msa of the processed MSA, a binary feature has_deletion saying if there is a deletion to the left, the normalised number of deletions deletion_value to the left, the distribution profile of the main MSA across residue types and the mean number of deletions deletion_mean in the main MSA.
  • The template features are the template sequence template_restype, the masks template_pseudo_beta_mask and template_backbone_frame_mask specifying if coordinates exist for the CB atom (CA for glycine) at this residue and for all backbone atoms at this residue, respectively, a pairwise encoding template_distogram of all CB atom distances and the unit vectors template_unit_vector of the CA atom displacements of all residues in the local frame of each residue.
  • Finally, there is one bond feature, token_bonds, a 2D matrix specifying if a bond between two tokens is present, restricted to polymer-ligand and ligand-ligand bonds. Bonds are only considered if they are shorter than 2.4 Å.
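
As an illustration of the last bullet, the following sketch builds such a 2D bond matrix from a bond list. The function name, the `(i, j, length)` bond format and the strict `<` comparison are assumptions made for this example.

```python
import numpy as np


def build_token_bonds(num_tokens, is_ligand, bonds, max_length=2.4):
    """Build a symmetric boolean matrix of bonds between tokens.

    `bonds` is a hypothetical list of (i, j, length_in_angstrom) tuples.
    Only polymer-ligand and ligand-ligand bonds below `max_length` are kept.
    """
    token_bonds = np.zeros((num_tokens, num_tokens), dtype=bool)
    for i, j, length in bonds:
        involves_ligand = is_ligand[i] or is_ligand[j]
        if involves_ligand and length < max_length:
            token_bonds[i, j] = True
            token_bonds[j, i] = True
    return token_bonds


# Tokens 0 and 1 are polymer residues, tokens 2 and 3 are ligand atoms.
example_is_ligand = [False, False, True, True]
example_bonds = [(0, 1, 1.33), (1, 2, 1.8), (2, 3, 1.5), (0, 3, 3.1)]
example_token_bonds = build_token_bonds(4, example_is_ligand, example_bonds)
```

In this example the polymer-polymer bond (0, 1) and the overly long contact (0, 3) are dropped, while the polymer-ligand bond (1, 2) and the ligand-ligand bond (2, 3) are kept.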

Embedding those features requires contextual embedding (here via the input embedder) and positional encoding.

In the input embedder, f denotes the features mentioned above. In AlphaFold3, two types of embeddings are performed.

  1. The bonds specified by the user via the token_bonds feature are linearly embedded. They pass through a linear layer consisting of a linear transformation with a weight matrix W and a bias vector b.
  2. The residue type, reference conformer, MSA profile and deletion_mean features are embedded with the algorithm shown below. This means that all per-atom features pass through an AtomAttentionEncoder, and the outputs of the encoder are then concatenated with the restype, MSA profile and deletion_mean features. The last two can be computed before MSA processing as they are only calculated on the main MSA.

But what does this AtomAttentionEncoder do? Let’s go through the algorithm step by step.

All per-atom features are first concatenated to a big matrix and then go through a linear layer without bias, meaning they are multiplied by a weight matrix. This creates output vectors cₗ for all Nₐₜₒₘₛ atoms of the input. Those vectors have a length of cₐₜₒₘ = 128, i.e. 128 learned features (real values) per atom coming from the input features of the molecule.

Position offsets are calculated between all combinations of two atoms in the reference conformer. Those offsets are multiplied by a weight matrix. The result is added to the pairwise embeddings pₗₘ only if atoms l and m originate from the same chain ID and residue index, which means that only offsets within a residue are used.

Pairwise inverse squared distances are calculated and multiplied with a weight matrix. Again, only the result for atom pairs of the same residue is added to the pairwise embeddings. In addition, the mask (1 if the distance belongs to the same residue, 0 if not) is also embedded after multiplying with a weight matrix.
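
The geometric inputs of these two steps can be sketched as follows. This is a simplified illustration that leaves out the learned weight matrices; the function name is hypothetical, and the `1 +` in the denominator is an assumption that keeps the inverse squared distance finite for an atom paired with itself.

```python
import numpy as np


def conformer_pair_features(ref_pos, ref_space_uid):
    """Raw pairwise features of a reference conformer (weight matrices omitted).

    ref_pos:       (N_atoms, 3) atom positions in Angstrom
    ref_space_uid: (N_atoms,) integer ID per (chain ID, residue index) combination
    """
    # Offset vector between every pair of atoms l and m.
    offsets = ref_pos[:, None, :] - ref_pos[None, :, :]
    # Mask: 1 if both atoms belong to the same residue, 0 otherwise.
    mask = (ref_space_uid[:, None] == ref_space_uid[None, :]).astype(float)
    # Inverse squared distance; the "+ 1" is an assumption, see the lead-in.
    inv_sq_dist = 1.0 / (1.0 + np.sum(offsets**2, axis=-1))
    # Only pairs within the same residue contribute.
    return offsets * mask[..., None], inv_sq_dist * mask, mask


example_pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
example_uid = np.array([0, 0, 1])  # atoms 0 and 1 share a residue
example_offsets, example_inv, example_mask = conformer_pair_features(
    example_pos, example_uid
)
```

Because a linear layer without bias maps zero to zero, masking before or after the weight multiplication gives the same result, which is why the mask can be applied to the raw features here.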

The single-atom representations cₗ are saved as qₗ for further manipulation.

The trunk embeddings (see below) for the token index of the token at position l undergo a layer normalisation (by effectively subtracting the mean and dividing by the standard deviation of all values). They are then multiplied by a weight matrix. The result is added to the atom single representations. Similarly, all pairwise embeddings of the token indices belonging to tokens l and m are added to the pairwise embeddings pₗₘ after applying layer normalisation and multiplication with a weight matrix. The noisy positions rₗ are added to qₗ after multiplying with a weight matrix. Note that all these steps are only performed if noisy positions rₗ are given, which is not the case for the input embeddings above.

The representations of single atoms l and m also influence their interaction. After applying a ReLU activation function and multiplying it with specific weight matrices, the sum is added to the pairwise embeddings.

The pairwise embeddings pass through a multilayer perceptron consisting of three layers with ReLU activation function and weight matrix without bias.

The previous steps have altered the embeddings of the single atoms cₗ, especially by including the trunk embeddings (step 9). So as not to lose the information from the original embeddings and to include all effects from pairwise embeddings with all other atoms m (mainly from the same residue), an AtomTransformer is applied, which uses multi-head cross-attention with three blocks and four heads per block applied on the input represented by the unmodified single atom embeddings qₗ, the modified single atom embeddings cₗ and the pairwise embeddings pₗₘ. The AtomTransformer performs the actual conditional diffusion, which will be described below.

The final single-atom representations pass another linear layer without bias and with the ReLU activation function. The results are converted into token embeddings aᵢ by taking the mean over all atoms that belong to token i. Finally, the embeddings of token i and the single atom and pairwise embeddings are returned by the AtomAttentionEncoder. In the case of the input embeddings above, only the token embeddings are used and concatenated with the restype, MSA profile and deletion_mean features to give the input embeddings sᵢ of token i.
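
The aggregation from atoms to tokens by taking the mean can be sketched in a few lines of NumPy. The function and argument names are hypothetical; `atom_to_token` is assumed to hold, for every atom, the index of the token it belongs to.

```python
import numpy as np


def atoms_to_tokens(atom_features, atom_to_token, num_tokens):
    """Average per-atom feature vectors over all atoms of each token."""
    sums = np.zeros((num_tokens, atom_features.shape[1]))
    # Unbuffered scatter-add: rows of atoms sharing a token are summed up.
    np.add.at(sums, atom_to_token, atom_features)
    counts = np.bincount(atom_to_token, minlength=num_tokens)
    # Guard against empty tokens to avoid a division by zero.
    return sums / np.maximum(counts, 1)[:, None]


example_atom_features = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
example_atom_to_token = np.array([0, 0, 1])  # two atoms in token 0, one in token 1
example_token_features = atoms_to_tokens(
    example_atom_features, example_atom_to_token, num_tokens=2
)
```

For a ligand, where every heavy atom is its own token, this mean simply passes the atom features through unchanged.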

Steps 2 and 3: Linear layer and pairwise embeddings

The input embeddings calculated in the first step are multiplied by a weight matrix to generate the initial single token embeddings sᵢ of token i. Initial pairwise token embeddings zᵢⱼ are generated from the single token embeddings of tokens i and j after multiplying them with specific weight matrices. Note that the single token embeddings have a length of 384, while the pairwise embeddings have a shorter length of 128.
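
A minimal sketch of these two steps with random weights is given below. The input width `c_in = 16` and all variable names are assumptions made for the example; only the output widths 384 and 128 come from the text. Whether the pair projection reads the raw input embeddings or the projected single embeddings is also treated as an assumption here; the sketch uses the projected ones.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, c_in = 5, 16      # hypothetical sizes for the example
c_s, c_z = 384, 128         # single and pair widths named in the text

s_inputs = rng.normal(size=(n_tokens, c_in))
w_single = rng.normal(size=(c_in, c_s))   # linear layer without bias
w_row = rng.normal(size=(c_s, c_z))       # projection for token i
w_col = rng.normal(size=(c_s, c_z))       # projection for token j

# Step 2: initial single token embeddings s_i.
s_init = s_inputs @ w_single

# Step 3: initial pair embeddings z_ij as the sum of two projections,
# broadcast so that entry (i, j) combines token i with token j.
z_init = (s_init @ w_row)[:, None, :] + (s_init @ w_col)[None, :, :]
```

The broadcasting trick avoids an explicit double loop over tokens: the pair tensor has shape (n_tokens, n_tokens, 128) and is in general not symmetric, because the two projections use different weights.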

Step 4: Relative positional encoding

The pairwise embeddings calculated above contain information about the token based on its features. However, they do not contain information on the order or position of the token in the input sequence. This is included in the positional encodings. AlphaFold2 and 3 use relative positional encodings between two tokens. Relative positional encodings have the advantage that the quality of the (AlphaFold) model does not decrease on sequences longer than those on which the model was trained. Without positional encodings, identical residues or chains at different positions of the input sequence could have the same embeddings.

Relative positional encodings are calculated using the following algorithm in AlphaFold3.

Algorithm of relative position encoding extracted from the Supporting Information of the AlphaFold3 paper.

Steps 1, 2 and 3 identify token pairs i and j from the same chain, with the same residue index and even with the same entity, respectively.

Steps 4 and 5 calculate the one-hot encoding of the relative position of two tokens, i and j, in the input sequence (calculated on their residue number residue_index) of the same chain (step 1). This relative position is clipped at rₘₐₓ = 32, meaning that residues that are more than 32 residues apart in the same chain are not distinguished by their position. This clipping has already been used in AlphaFold2. It reduces the effect of primary sequence distances; in other words, it emphasises the impact of other input feature differences between tokens, as calculated above.

Steps 6 and 7 calculate the one-hot encoding of the token index difference between tokens i and j of the same chain (step 1) and residue (step 2) using the same clipping cutoff rₘₐₓ = 32. Steps 9 and 10 calculate the same kind of one-hot encoding on the chain index difference between tokens i and j if they are part of different chains (step 1). In this case, a cutoff sₘₐₓ = 2 is used, so chain index differences of more than two are not distinguished from each other.

Finally, the one-hot encodings of the relative residue number difference, the relative token index difference and the relative chain index difference between tokens i and j and a mask indicating if those tokens belong to the same entity are concatenated and multiplied by a weight matrix to generate the relative positional encodings of tokens i and j. These are added to the pairwise token embeddings calculated above to create new pairwise input embeddings.
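
The clipped one-hot encoding of the residue number difference can be sketched as follows. The sketch covers only the residue-offset part of the encoding; the function names are hypothetical, and reserving one extra bin for token pairs from different chains is an assumption about how such pairs are marked.

```python
import numpy as np


def one_hot(indices, num_classes):
    """One-hot encode an integer array along a new last axis."""
    return np.eye(num_classes)[indices]


def relative_residue_encoding(residue_index, asym_id, r_max=32):
    """Clipped one-hot encoding of residue_index differences within a chain."""
    same_chain = asym_id[:, None] == asym_id[None, :]
    offset = residue_index[:, None] - residue_index[None, :]
    # Shift into [0, 2 * r_max]; everything further apart falls into the end bins.
    clipped = np.clip(offset + r_max, 0, 2 * r_max)
    # Assumption: pairs from different chains share one extra bin.
    bins = np.where(same_chain, clipped, 2 * r_max + 1)
    return one_hot(bins, 2 * r_max + 2)


example_residue_index = np.array([0, 1, 50, 0])
example_asym_id = np.array([0, 0, 0, 1])  # the last token is in a second chain
example_encoding = relative_residue_encoding(example_residue_index, example_asym_id)
```

In the example, residue 50 is 50 and 49 positions away from residues 0 and 1. Both offsets exceed the cutoff, so the two pairs end up in the same bin and are not distinguished by position.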

Step 5: Linear layer

So far, all features are embedded except user-defined bonds between tokens. As mentioned above, the user of AlphaFold3 has the option to specify bonds between tokens i and j (e.g. heavy atoms in non-protein/RNA/DNA molecules) via the token_bonds feature. This feature is added to the pairwise embeddings after multiplying it with a weight matrix.

This step generates the initial pairwise embeddings. Together with the single token embeddings from Step 2, we can move on to the first modules of AlphaFold3.

Step 6: Initialization

The recycled pairwise token and single token embeddings are initialised to the 0 vector and will be updated in every round of the following cycle.

Step 7: Iterations

This is not an actual step, but it tells us that the following steps 8 (Linear Layer and Layer Normalization), 9 (TemplateEmbedder), 10 (MsaModule), 11 (another Linear Layer and Layer Normalization), 12 (Pairformer Stack) and 13 (Value update) are performed for N_cycle = 4 recycling iterations.

Step 8: Linear Layer on pair representations

The updated pairwise token embeddings from the previous cycle (0-vector in the first cycle) pass through a layer normalisation and multiplication with a weight matrix at the beginning of every new recycling iteration. The outputs are added to the initial embeddings from step 5 to generate pairwise embeddings for this cycle.
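
This recycling update can be written as a short sketch. The layer normalisation here has no learned scale and offset, and the function names are hypothetical; it is meant as an illustration of the data flow, not as the actual implementation.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalise the last axis to zero mean and unit variance (no learned parameters)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def recycle_pair(z_init, z_prev, weight):
    """Pair embeddings for the current cycle: initial embeddings plus the
    normalised and linearly projected output of the previous cycle."""
    return z_init + layer_norm(z_prev) @ weight


rng = np.random.default_rng(0)
example_z_init = rng.normal(size=(3, 3, 4))
example_weight = rng.normal(size=(4, 4))

# First cycle: the previous output is the 0 vector, so z equals z_init.
example_z_first = recycle_pair(example_z_init, np.zeros((3, 3, 4)), example_weight)
```

Because the normalised 0 vector stays 0, the first cycle starts exactly from the initial embeddings, and only later cycles receive a contribution from the previous Pairformer output.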

Step 9: Template Embedder

The template embedder uses the generated pairwise embeddings of the previous step 8 and features from the template search as input. The main task of the template embedder is to attend to those regions in the template with higher weights based on the current values of the pairwise embeddings. Updating those pairwise embeddings shifts the focus to the “more important” regions of the template structure.

The features are extracted from a template search for the individual protein chains, based on the UniRef90 MSA of the input sequences. Very long sequences are cropped to the first 300 residues. Hidden Markov models (HMMs) are built from the MSA using hmmbuild, and templates are searched with the resulting model using hmmsearch. Templates with fewer than ten residues and templates with less than 10% or more than 95% sequence coverage of the query are removed. From the remaining templates, four are used during inference and up to four during AlphaFold3 training, ranked by their e-value. Structural data for the template is extracted from PDB70; if the template sequence does not exactly match the sequence in the corresponding mmCIF file of the PDB, both are first aligned using KAlign. In summary, the template search in AlphaFold3 has remained the same compared to AlphaFold2 except for some details like the cutoff dates for the templates.
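
The filtering rules of this paragraph can be summarised in a small sketch. The hit format (dictionaries with the keys `aligned_length` and `e_value`) and the definition of coverage as aligned length divided by query length are assumptions made for this example.

```python
def select_templates(hits, query_length, max_templates=4):
    """Filter template hits and keep the best ones by e-value.

    `hits` is a hypothetical list of dicts with 'aligned_length' and 'e_value'.
    """
    kept = []
    for hit in hits:
        coverage = hit["aligned_length"] / query_length
        if hit["aligned_length"] < 10:
            continue  # too short
        if coverage < 0.10 or coverage > 0.95:
            continue  # covers too little or almost all of the query
        kept.append(hit)
    # A smaller e-value means a more significant hit.
    kept.sort(key=lambda hit: hit["e_value"])
    return kept[:max_templates]


example_hits = [
    {"aligned_length": 5, "e_value": 1e-30},    # removed: shorter than 10 residues
    {"aligned_length": 15, "e_value": 1e-25},   # removed: 7.5% coverage
    {"aligned_length": 198, "e_value": 1e-40},  # removed: 99% coverage
    {"aligned_length": 100, "e_value": 1e-5},
    {"aligned_length": 150, "e_value": 1e-20},
    {"aligned_length": 50, "e_value": 1e-3},
    {"aligned_length": 60, "e_value": 1e-8},
    {"aligned_length": 80, "e_value": 1e-2},
]
example_templates = select_templates(example_hits, query_length=200)
```

The sketch corresponds to the inference setting with a fixed number of four templates; the variable number of templates used during training is not modelled.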

Features are extracted from the templates’ structure and fed into the template embedder together with the pairwise embeddings of the previous cycle.

Template embedder algorithm extracted from the AlphaFold3 paper.

In steps 1–5, the template_backbone_frame_mask, template_distogram, template_restype, template_pseudo_beta_mask and template_unit_vector features are concatenated. Please check the definitions in step 1 of the main algorithm above. The corresponding embeddings a_tij contain all information about those features for all t templates and all pairs of tokens i and j. Those feature embeddings are multiplied by a weight matrix and added to the pairwise embeddings of the previous step after those are normalised and multiplied by a weight matrix. The combined embeddings go through two blocks of the PairformerStack, described below, which replaces the Evoformer module in AlphaFold2. Results are added to the unmodified combined embeddings via a residual connection to keep the features from before the Pairformer blocks and are then normalised. After doing this for all template structures, the output embeddings are averaged over the number of template structures and go through a linear layer with the ReLU activation function.

In summary, the template embedder uses the current pairwise embeddings of every cycle to attend to the currently most essential regions in all templates. By applying this to all templates, the structural focus is moved to the structural changes in those regions of the protein with higher weights in the pairwise embeddings. The template pair stack in AlphaFold2 followed a similar concept of attending pairwise features to template structures but was built upon different layer organisations.

Step 10: MSA module

The MSA module’s task is to update the pair representations with information from a new subset of the MSA in each recycling iteration. The structure of the MSA module is shown in the following figure.

MSA module of AlphaFold3. Figure extracted from the original paper.

The MSA module calculates new pairwise representations using the features extracted from an MSA subset, the pairwise representations from the output of the Template module and the single token representation as input.

The MSA module consists of four blocks, far fewer than the 48 Evoformer blocks in AlphaFold2. Each block contains a pair stack similar to the one in AlphaFold2. However, this is not the only change in the MSA module. The most important technical difference is that the MSA module of AlphaFold3 does not use row-wise gated self-attention. To understand this, let’s look at how the MSA is generated.

The MSA in AlphaFold3 consists of up to 16,384 rows. The first row is the query (input) sequence. The following 8191 rows (or fewer, if fewer alignments are found) are constructed by copying the MSA n times for homomeric complexes, with n as the number of chain repeats, or by stacking the MSAs of each chain from left to right after pairing the sequences for heteromeric complexes. The remaining unpaired MSA sequences are added below, capped at 8191 rows. The other half of the MSA rows is filled with the original MSA. For the genetic search, Jackhmmer and HHBlits are used to search in the UniRef90, UniProt, Uniclust30 + BFD, Reduced BFD and MGnify protein databases, and nhmmer is used to search in the Rfam, RNACentral and Nucleotide collection RNA chain databases.

In the end, the different MSA sequences are in the rows of the MSA matrix, while aligned residues are in the columns. Applying row-wise gated self-attention in AlphaFold2 generates attention weights for residue pairs from the MSA representation. Pair embeddings are included as an additional bias. In AlphaFold3, the weights are generated from the pair embeddings alone and are applied independently to each row. In other words, this change focuses more on the pair representations than on the MSA between residue pairs. However, those pair representations contain some information from residue pairs within rows of the MSA coming from the input embeddings.

What is the effect of this change? Well, the MSA row attention attends to different residue pairs of the same sequence. This is represented by the features leading to the interaction between the tokens corresponding to those residues. Those features should be encoded in the pair representation of those residues, so it makes sense to let the complete information flow through the pair representations.

Let’s have a look at the MSA module algorithm.

First, all MSA features are concatenated with the has_deletion and the deletion_value features, which tell us if and how many deletions are to the left in the MSA (see explanation in Step 1 of the main algorithm). Then, a random subset of the MSA sequences is selected, and the corresponding representations of all its tokens are multiplied with a weight matrix. Via a residual connection, the input single token representations are added after they have been multiplied with another weight matrix. Those representations enter the MSA block. An OuterProductMean layer consisting of layer normalisation, multiplication with a weight matrix, averaging, flattening and a linear layer is applied to those representations.
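
The core of such an OuterProductMean layer can be sketched with a single einsum. Layer normalisation and biases are left out, and the function name, weight names and sizes are assumptions made for illustration.

```python
import numpy as np


def outer_product_mean(msa, w_a, w_b, w_out):
    """Turn an MSA representation (S, N, c_m) into a pair update (N, N, c_z).

    Simplified sketch: layer normalisation and biases are omitted.
    """
    a = msa @ w_a  # (S, N, c) projection for token i
    b = msa @ w_b  # (S, N, c) projection for token j
    # Outer product of the two projections, averaged over the S sequences.
    outer = np.einsum("sic,sjd->ijcd", a, b) / msa.shape[0]
    # Flatten the (c, c) outer product and project it to the pair width.
    flat = outer.reshape(outer.shape[0], outer.shape[1], -1)
    return flat @ w_out


rng = np.random.default_rng(0)
example_msa = rng.normal(size=(2, 3, 4))   # 2 sequences, 3 tokens, 4 channels
example_w_a = rng.normal(size=(4, 2))
example_w_b = rng.normal(size=(4, 2))
example_w_out = rng.normal(size=(4, 5))    # 2 * 2 flattened channels to 5 pair channels
example_pair_update = outer_product_mean(
    example_msa, example_w_a, example_w_b, example_w_out
)
```

This is the step through which information from the MSA rows reaches the pair representation: columns i and j that co-vary across sequences produce a non-zero average outer product.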

The results enter the MSA stack and the pair stack. The MSA stack performs a pair-weighted averaging on those embeddings, followed by a dropout of 0.15 on the MSA rows. This ensures that the embeddings see a different part of the MSA at every new execution. Finally, a transition layer with layer normalisation, linear layers and SwiGLU activation is applied to generate the input for the following MSA stack or outer product mean with the pair representations.
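
The pair-weighted averaging can be illustrated with a single-head sketch without gating. The function names and the scalar `pair_bias` input (standing for one channel projected from the pair representation) are assumptions made for this example.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def pair_weighted_average(msa, pair_bias):
    """Average MSA embeddings over residues with weights from the pair representation.

    msa:       (S, N, c) embeddings of S sequences and N tokens
    pair_bias: (N, N) logits derived from the pair representation (one head)
    """
    # The weights depend only on the pair logits, not on the MSA rows.
    weights = softmax(pair_bias, axis=-1)
    # The same weights are applied to every row s independently.
    return np.einsum("ij,sjc->sic", weights, msa)


rng = np.random.default_rng(0)
example_msa_rows = rng.normal(size=(3, 4, 2))
# With identical logits, every token receives the plain mean over all tokens.
example_uniform = pair_weighted_average(example_msa_rows, np.zeros((4, 4)))
```

In contrast to row-wise self-attention, no queries and keys are computed from the MSA itself, which is what makes this operation much cheaper.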

The pair stack contains two triangular multiplications, two triangular attentions and a transition layer. Those triangular updates were introduced in AlphaFold2. They ensure that pairwise representations between tokens (in AlphaFold2, only amino acids) can be represented as a 3D structure with constraints like the triangle inequality of distances. In short, update operations are arranged in triangles of edges with three nodes. The missing node is added via axial attention, updated in the triangle multiplicative update steps and finally in the attention steps. As this pair stack step has remained the same as in AlphaFold2, I refer to the AlphaFold2 paper for further details about the triangle layers.
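
The core of one triangle multiplicative update (the variant using "outgoing" edges) can be sketched as follows. Gating, layer normalisation and the linear projections that produce `a` and `b` from the pair representation are omitted, so this is only the central sum over the third node k.

```python
import numpy as np


def triangle_update_outgoing(a, b):
    """Update edge (i, j) from the two other edges of every triangle (i, j, k).

    a, b: (N, N, c) projections of the pair representation.
    Returns an (N, N, c) tensor with entries sum_k a[i, k] * b[j, k].
    """
    return np.einsum("ikc,jkc->ijc", a, b)


rng = np.random.default_rng(0)
example_a = rng.normal(size=(3, 3, 2))
example_b = rng.normal(size=(3, 3, 2))
example_triangle = triangle_update_outgoing(example_a, example_b)
```

Each edge (i, j) therefore receives information from all pairs of edges (i, k) and (j, k) that close a triangle with it, which is how geometric consistency between pairwise features is encouraged.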

After passing the four blocks of the MSA module, the pair representation output of its last pair stack is passed on to the Pairformer stack.

Step 11: Linear layer on single representations

The Template and MSA modules extracted information about the relationship between tokens, represented by the pair representation. Consequently, only the pair representation was modified in those modules. Those modules have not updated the single token representation. The single and pair representations are used in the Pairformer and recycled during the different iterations. This requires an update of the single representation using the previous iteration's output. This is done via a Layer normalisation and multiplication with a weight matrix, applied to the representations after the previous iteration or to the 0 vector in the first iteration cycle. The output is then added to the initial single token embeddings from the Input embedder (Step 2) to generate new inputs directly used in the following Pairformer without further modifications. This step mirrors Step 8 for the pair representations.

Step 12: Pairformer

The Pairformer module in AlphaFold3 is similar to the Evoformer module in AlphaFold2. It uses the single representations from the previous iterations added to the single representations from the input embedder and the pair representations from the MSA module as input. Compared to the Evoformer module in AlphaFold2, a specific MSA representation input is not used. The information from the MSA is already included in the pair representations (see Step 10).

The Evoformer in AlphaFold2 uses a subset of the MSA as a representation and column-wise attention is applied. This is not needed in the Pairformer module. As only the single representation is used, only row-wise attention on a single sequence with the embeddings from the single representations is applied. Let’s look at the Pairformer stack algorithm.

Pairformer stack. Figure extracted from the AlphaFold3 paper.

One can directly see that the steps are similar to those performed in the MSA module. Like in the MSA module, the pair representations pass two triangular update layers, two triangular self-attention layers and one transition layer with the SwiGLU activation function. In all layers, residual connections are applied. The single representations do not influence the pair representations. Intuitively, this means that the features representing the interactions between a pair of tokens (e.g. a pair of amino acids or nucleotides) are influenced by the pair features but not by the single token features. However, features representing the interaction with other tokens influence the features of a single token. This is defined by the single Attention layer with pair bias through which the single representations are passed. This layer has inputs from the single representation and the output of the pair representation of this block. It applies multi-head self-attention with 16 heads on the single representations using the pair representations as bias. The output passes another Transition layer, and residual connections are applied to both layers. The outputs of the transition layers from the pair and single stack are used as inputs for the following block. Like in the Evoformer of AlphaFold2, 48 blocks are applied, making the Pairformer the main processing module in the AlphaFold3 architecture.
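
The attention with pair bias on the single representation can be sketched with one head instead of 16. Gating, layer normalisation and the output projection are omitted, and all names are hypothetical.

```python
import numpy as np


def attention_with_pair_bias(single, pair_bias, w_q, w_k, w_v):
    """Single-head self-attention over tokens with an additive pair bias.

    single:    (N, c_s) single token representations
    pair_bias: (N, N) logits projected from the pair representation
    """
    q, k, v = single @ w_q, single @ w_k, single @ w_v
    # Scaled dot-product logits plus the bias from the pair representation.
    logits = q @ k.T / np.sqrt(q.shape[-1]) + pair_bias
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
example_single = rng.normal(size=(4, 8))
example_w_q, example_w_k, example_w_v = (rng.normal(size=(8, 8)) for _ in range(3))
# A strongly negative bias on column 2 stops every token from attending to token 2.
example_bias = np.zeros((4, 4))
example_bias[:, 2] = -1e9
example_out, example_weights = attention_with_pair_bias(
    example_single, example_bias, example_w_q, example_w_k, example_w_v
)
```

The example shows the direction of the information flow described above: the pair representation steers which tokens a single token attends to, while the pair representation itself is not changed by this layer.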

Step 13: Value updates and recycling

The single and pair representation outputs from the Pairformer module are either the inputs for the following recycling cycle or the input of the following Diffusion and ConfidenceHead modules.

Step 14: End of iteration

The Pairformer module marks the end of an iteration in the AlphaFold3 process. The output is recycled as input for the following iteration. Four such cycles are passed with a constant update of the single and pair representations using new subsets of the MSA. At the end of all cycles, representations for single tokens and pairs of tokens are learned and can be used for structure generation in the following Diffusion module.

Step 15: Diffusion module

The Diffusion module generates a distribution of structures and replaces the Structure module in AlphaFold2. The SampleDiffusion algorithm takes the single and pair representations from the Pairformer module, the features and the single input embeddings from the InputEmbedder as input and generates coordinates for all atoms.

As this module is new in AlphaFold3, let’s look at the SampleDiffusion algorithm in detail to understand what is done.

First, initial positions are sampled from a Gaussian distribution around the origin. Then, the algorithm steps through a schedule of T noise levels. At each noise level, the positions of the atoms are centred and randomly rotated, and a noisy displacement is added. Those noisy positions enter the Diffusion module. The denoised positions estimated by this module are used to update the atom positions. Repeating this for all T noise levels yields the final set of atom positions.
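
The sampling loop can be sketched as a toy example. It loosely follows the description above and is not the exact procedure of the paper: the noise schedule, the factor `gamma`, the update rule and the placeholder `denoise` function are all assumptions made for illustration, and the random translation is left out.

```python
import numpy as np


def random_rotation(rng):
    """Draw a random 3D rotation matrix via a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))   # fix the sign ambiguity of the decomposition
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]        # turn a reflection into a proper rotation
    return q


def sample_diffusion(denoise, num_atoms, noise_levels, rng, gamma=0.1):
    """Toy sampler: start from Gaussian noise and denoise along a noise schedule."""
    x = noise_levels[0] * rng.normal(size=(num_atoms, 3))
    for c_prev, c_next in zip(noise_levels[:-1], noise_levels[1:]):
        # Centre the positions and apply a random rotation.
        x = (x - x.mean(axis=0)) @ random_rotation(rng).T
        # Raise the noise level slightly and add the matching noisy displacement.
        t_hat = c_prev * (1.0 + gamma)
        x_noisy = x + np.sqrt(t_hat**2 - c_prev**2) * rng.normal(size=x.shape)
        # The denoiser estimates clean positions from the noisy ones.
        x_denoised = denoise(x_noisy, t_hat)
        # Move from noise level t_hat towards c_next along the denoising direction.
        direction = (x_noisy - x_denoised) / t_hat
        x = x_noisy + (c_next - t_hat) * direction
    return x


rng = np.random.default_rng(0)
example_target = rng.normal(size=(6, 3))
# Placeholder denoiser that always returns the same structure.
example_sample = sample_diffusion(
    lambda x_noisy, t_hat: example_target,
    num_atoms=6,
    noise_levels=[10.0, 5.0, 1.0, 0.0],
    rng=rng,
)
```

With this placeholder denoiser and a schedule that ends at zero noise, the final step lands exactly on the denoiser's estimate. In the real model, the denoiser is the Diffusion module conditioned on the trunk representations.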

The Diffusion module, which denoises the atom positions, is shown in the following figure.

Diffusion module. Figure extracted from the AlphaFold3 paper.

Let’s look at the algorithm of the Diffusion model step by step.

First, the diffusion is conditioned using the input calculated in the previous steps. The pair embeddings are concatenated with the relative positional encodings of the input features, normalised, multiplied with a weight matrix and passed through two transition layers with the SwiGLU activation function and residual connection. The single representations from the Pairformer output and the input embedder are also concatenated, normalised and multiplied with a weight matrix. Then, the current noise level is encoded with Fourier embeddings based on a cosine function. In my previous text about transformers, I explained why trigonometric functions are used for positional embeddings. These noise embeddings are normalised, multiplied with a weight matrix and added to the single representation, which is then passed through two Transition layers with residual connections.
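As a small illustration of the Fourier embedding, the sketch below encodes a scalar noise level as a vector of cosines with random frequencies and phases. The helper name `make_fourier_embedding`, the embedding size of 256, the value of `SIGMA_DATA` and the logarithmic scaling of the noise level are assumptions on my side.

```python
import numpy as np

def make_fourier_embedding(dim, rng):
    """Return an embedding function with fixed random frequencies and phases."""
    w = rng.normal(size=dim)  # drawn once and then kept fixed
    b = rng.normal(size=dim)

    def embed(value):
        return np.cos(2.0 * np.pi * (value * w + b))

    return embed

SIGMA_DATA = 16.0  # assumed data scale in Angstrom
embed = make_fourier_embedding(256, np.random.default_rng(0))

t_hat = 4.0  # current noise level
noise_embedding = embed(0.25 * np.log(t_hat / SIGMA_DATA))
```

The resulting vector tells the network how noisy its input is, so the same weights can handle both coarse and fine denoising steps.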

Second, the noisy input positions are scaled to generate dimensionless vectors with nearly unit variance.
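A quick numerical check shows why such a scaling gives roughly unit variance. I assume here that clean coordinates have a standard deviation of `SIGMA_DATA` per dimension and that the noise at level `t_hat` is independent Gaussian; both the constant and the function name are hypothetical.

```python
import numpy as np

SIGMA_DATA = 16.0  # assumed standard deviation of clean coordinates (Angstrom)

def scale_noisy_positions(x_noisy, t_hat, sigma_data=SIGMA_DATA):
    # Var(clean + noise) = sigma_data^2 + t_hat^2, so divide by its square root.
    return x_noisy / np.sqrt(t_hat**2 + sigma_data**2)

rng = np.random.default_rng(0)
t_hat = 40.0
x_clean = SIGMA_DATA * rng.normal(size=(10_000, 3))
x_noisy = x_clean + t_hat * rng.normal(size=x_clean.shape)
r_noisy = scale_noisy_positions(x_noisy, t_hat)
print(round(float(r_noisy.std()), 2))  # close to 1 for any noise level
```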

Third, the rescaled positions, features, single embeddings and pairwise embeddings pass the AtomAttentionEncoder described in Step 1 of the main algorithm. In contrast to Step 1 of the main algorithm, the noisy positions are now given as an additional input. This means that the encoder's single and pairwise embeddings and the noisy positions are updated with linear layers and residual connections (see description above). In addition, the encoder generates updated atom representations qₗ from the Atom Transformer using multi-head self-attention.

The Atom Transformer uses rectangular blocks along the diagonals of the pairwise matrices for local atom attention in the sequence. This ensures that only atoms in the vicinity are attended. With those attention masks, a Diffusion Transformer of 24 blocks is called, which uses Adaptive LayerNorm for the single conditioning and logit biasing for the pair conditioning. I will write a separate article about diffusion transformers and link this text here, but this Diffusion Transformer does the magic to denoise the structure in this step on the atom level.
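The rectangular blocks along the diagonal can be visualised with a small mask builder. The block sizes of 32 queries and 128 keys are the values I recall from the supplement of the paper, so please treat them, as well as the function name `local_attention_bias`, as assumptions.

```python
import numpy as np

def local_attention_bias(n_atoms, n_queries=32, n_keys=128, neg=-1e10):
    """Additive attention bias: 0 inside the local blocks, a large negative value outside."""
    idx = np.arange(n_atoms)
    bias = np.full((n_atoms, n_atoms), neg)
    n_blocks = -(-n_atoms // n_queries)  # ceiling division
    centres = n_queries * np.arange(n_blocks) + (n_queries - 1) / 2.0
    for c in centres:
        query_rows = np.abs(idx - c) < n_queries / 2   # 32 consecutive query atoms
        key_cols = np.abs(idx - c) < n_keys / 2        # 128 keys around the same centre
        bias[np.ix_(query_rows, key_cols)] = 0.0
    return bias

bias = local_attention_bias(100)
allowed_keys_for_first_atom = int((bias[0] == 0.0).sum())
```

Adding this bias to the attention logits before the softmax removes all atom pairs outside the blocks, so the cost grows linearly instead of quadratically with the number of atoms.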

After conditioning on atoms, self-attention is applied on the token level using a similarly structured Diffusion Transformer. The final token representations are normalised.

Next, the token embeddings are used in an Atom Attention Decoder. This decoder first converts the token activations into atom representations and then uses the Atom Transformer from above to generate new positions.

The updated positions from the second Atom Transformer are rescaled to dimensionless vectors with nearly unit variance and then added to the rescaled noisy atom positions to generate the new denoised atom positions.
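One way to write this final combination is the preconditioning known from EDM-style diffusion models, sketched below. I assume this form here for illustration; the function name `combine_update` and the value of `sigma_data` are hypothetical.

```python
import numpy as np

def combine_update(x_noisy, r_update, t_hat, sigma_data=16.0):
    """Blend the noisy positions with the network output, depending on the noise level."""
    denom = sigma_data**2 + t_hat**2
    skip_scale = sigma_data**2 / denom               # trust the input at low noise
    out_scale = sigma_data * t_hat / np.sqrt(denom)  # trust the network at high noise
    return skip_scale * x_noisy + out_scale * r_update

rng = np.random.default_rng(0)
x_noisy = rng.normal(size=(4, 3))
r_update = rng.normal(size=(4, 3))
x_low_noise = combine_update(x_noisy, r_update, t_hat=0.0)
x_high_noise = combine_update(x_noisy, r_update, t_hat=160.0)
```

At zero noise the input positions are returned unchanged, while at high noise the result is dominated by the network's update.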

In summary, the Diffusion module of AlphaFold3, which replaces the Structure module in AlphaFold2, predicts true atomic coordinates from noisy atomic coordinates with a standard diffusion model without rotational frames and equivariance. Different noise levels are used to learn protein structures on a local level (low noise) and on larger length scales (high noise). The diffusion model generates a distribution of structures, which allows the prediction of accurate side-chain geometries. However, a generative model can also produce physically plausible-looking structure in regions that are actually unstructured. To estimate how confident AlphaFold3 is about the predicted denoised structures, the output of the Diffusion module is passed to a Confidence module.

Step 16: Confidence Head

The Confidence module uses the single inputs from the Input embedder, the single and pair representations from the last iteration of the Pairformer module and the predicted coordinates from the Diffusion module to estimate a confidence level of the structure prediction. This is important, e.g., to filter out plausible-looking structures in unstructured regions.

Let’s look at the Algorithm of the Confidence Head.

The Confidence Head module starts by multiplying the single embeddings of two tokens, i and j, from the Input embedder with weight matrices. Both outputs are added to the pair embedding of the two tokens. Then, the distance between the representative atoms of both tokens is sorted into bins defined by distance thresholds and one-hot encoded. After multiplying with another weight matrix, this distance information is also added to the pair embeddings. In other words, the pair embeddings now contain the information in which distance bin the representative atoms of both tokens fall, together with the single token embedding information of the input embedder. Those representations go through 4 blocks of the Pairformer module with residual connections and are updated accordingly. The updated single and pair embeddings are multiplied with weight matrices. Then, the softmax function is used to predict binned probabilities for four confidence measures: the predicted local distance difference test (pLDDT) on individual atoms, the predicted aligned error (PAE) and the predicted distance error (PDE) on pairs of tokens, and whether an individual atom is experimentally resolved in the ground truth.
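Two of these steps are easy to show in code: the one-hot encoding of binned distances and the conversion of predicted bin probabilities into a single pLDDT value. The bin edges, the number of pLDDT bins (50) and the use of the expected value over bin centres are assumptions for illustration, and the function names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def one_hot_distances(rep_coords, bin_edges):
    """One-hot encode the distance bin for every pair of representative atoms."""
    diff = rep_coords[:, None, :] - rep_coords[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    bin_index = np.digitize(dist, bin_edges)       # values in 0 .. len(bin_edges)
    return np.eye(len(bin_edges) + 1)[bin_index]   # shape (N, N, n_bins)

def expected_plddt(logits):
    """Turn per-atom bin logits into a score between 0 and 100."""
    n_bins = logits.shape[-1]
    bin_centres = (np.arange(n_bins) + 0.5) * 100.0 / n_bins
    return softmax(logits) @ bin_centres

rng = np.random.default_rng(0)
rep_coords = 10.0 * rng.normal(size=(6, 3))   # toy representative atoms
bin_edges = np.arange(4.0, 22.0, 2.0)         # hypothetical thresholds in Angstrom
dist_features = one_hot_distances(rep_coords, bin_edges)
plddt = expected_plddt(np.zeros((6, 50)))     # uniform probabilities as a toy input
```

With uniform probabilities the expected pLDDT is 50 for every atom; a confident prediction concentrates the probability mass in the highest bins.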

Step 17: Distogram Head

In addition to the confidence metrics in the previous step, AlphaFold3 calculates the Distogram of predicted binned distances between all pairs of tokens. A representative atom is chosen to calculate the distogram for tokens containing more than one atom. That is the CB atom for all protein residues except Glycine, the CA atom for Glycine, the C4 atom for purines and the C2 atom for pyrimidines.

The prediction of the distogram first symmetrises the pair representations, i.e., it calculates z_ij + z_ji. The result is projected linearly into bins of equal width from 2 Å to 22 Å. A softmax function is then used to get the probabilities p_ij.
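These three operations fit into a few lines of NumPy. The number of bins (64), the random weights and the function name `distogram_probabilities` are assumptions for this sketch.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def distogram_probabilities(z, w, b):
    """z: (N, N, c_z) pair representation, w: (c_z, n_bins), b: (n_bins,)."""
    z_sym = z + np.swapaxes(z, 0, 1)   # z_ij + z_ji
    logits = z_sym @ w + b             # linear projection into distance bins
    return softmax(logits, axis=-1)    # probabilities p_ij over the bins

n_tokens, c_z, n_bins = 5, 8, 64
bin_edges = np.linspace(2.0, 22.0, n_bins + 1)   # equal-width bins in Angstrom

rng = np.random.default_rng(0)
z = rng.normal(size=(n_tokens, n_tokens, c_z))
w = rng.normal(size=(c_z, n_bins))
b = np.zeros(n_bins)
p = distogram_probabilities(z, w, b)
```

The symmetrisation guarantees that the predicted distance distribution between tokens i and j is the same as between j and i.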

Step 18: Predictions

AlphaFold3 returns the predicted coordinates of the generated structures, the distogram probabilities and the confidence values based on the pLDDT score, the PAE score, the PDE score and the predicted probability that an atom is experimentally resolved.

Summary

Similar to the previous versions of AlphaFold, AlphaFold3 is the new SOTA method for protein structure prediction. It is slightly better at single-chain predictions than AlphaFold2, but the main advantage of AlphaFold3 is the accurate prediction of complexes. Especially remarkable is the improved prediction of antibody-antigen interfaces in comparison to AlphaFold-multimer v2.3.

Those advancements are based on a new architecture in AlphaFold3. While AlphaFold3 builds on AlphaFold2 and the difference between the AlphaFold3 and AlphaFold2 architectures is much less pronounced than between AlphaFold2 and AlphaFold1, significant modifications have been made. They include:

  • The replacement of the Structure module in AlphaFold2 by a Diffusion module in AlphaFold3. It is remarkable that the prediction is improved without introducing invariant or equivariant constraints.
  • The expansion of the vocabulary from amino acids representing proteins to nucleotides representing RNA and DNA and heavy atoms representing all chemical molecules, including ligands
  • The MSA module is much smaller (only four blocks) in AlphaFold3 than in AlphaFold2, and it has been removed from the new Pairformer module.
  • The Pairformer module in AlphaFold3 replaces the Evoformer module in AlphaFold2. While the structures of both modules are similar except for the removed MSA module (see last point), the Pairformer module has some additional internal changes. For example, information flows from the pair to the single representations but not in the other direction.
  • Most but not all ReLU activation functions in AlphaFold2 are replaced by SwiGLU activation functions to improve performance.

Some of those modifications are needed to accurately predict protein complexes (e.g. the focus on pair representations), while others are updates to Machine Learning (ML) approaches that have emerged and spread in the years since AlphaFold2 was published (e.g. the diffusion model or SwiGLU activation functions). It will be interesting to see how open-source models like OpenFold use those approaches to improve their accuracy. We can expect several new complex prediction tools soon, most of which will use approaches implemented in AlphaFold3.

Resource:

Abramson, J., Adler, J., Dunger, J. et al. Accurate structure prediction of biomolecular interactions with AlphaFold 3. Nature (2024). https://doi.org/10.1038/s41586-024-07487-w
