GNNs and atomic coordinates: predicting a protein’s conformational landscape

Leah Reeder
Stanford CS224W GraphML Tutorials
14 min read · May 22, 2023

By Leah Reeder and Xun Tang for CS224W course project Winter ‘23.

Motivation

Titles like "AlphaFold Is The Most Important Achievement in AI — Ever" do wonders to indicate how vital computational biology, and more specifically protein structure prediction, is as a research area, especially as computational power grows.


Protein structure prediction is the problem of determining the structure of a protein from its amino acid sequence. Current methods range from imaging, e.g. X-ray crystallography and cryo-EM, to molecular dynamics simulation, e.g. ab initio molecular dynamics.

In the past decade, machine-learning-based structure prediction has emerged as one of the most important tools for determining protein structure, with one example being AlphaFold from DeepMind. Oftentimes these prediction methods give a single average 'solved' structure for the protein. However, it is also important to be able to model a protein's conformational range, i.e. the different shapes a protein can take on depending on its function. The conformation that a protein is in can determine whether or not it will bind to a receptor, which shows how important this knowledge is for problems like drug discovery.


In this blog post, we want to show how one can use graph neural networks (GNNs) to learn to characterize a protein’s conformational range.

We will demonstrate how GNNs can help us learn a protein's conformational range, namely by discovering a latent space representation of a biomolecule using data from molecular dynamics (MD) simulations. We will be using an equivariant GNN built with the torch-geometric (PyG) package. PyG is a tool for working with graph-structured data in Python. It provides an intuitive API for constructing and manipulating graphs, as well as methods for graph-based computations.

Colab

Make sure to follow along with our associated Google Colab notebook. Note that to train the network it might be necessary to download the notebook and run it locally.

Setting

Molecular abstraction

A graph G = (V, E) consists of a set of vertices or nodes V and a set of edges E, where each edge is a pair (u, v) representing a connection between nodes u and v. Node-level features of a graph refer to the attributes or characteristics associated with each individual node in the graph. In our case, the graph represents a biomolecule. Each node represents individual atoms within the molecule, and the edges represent the chemical bonds between those atoms. In this context, the graph would be a “molecular graph”.

The node-level features we use will be the location of atoms using just the backbone structure (i.e. nitrogen, carbon, carbon-α, and oxygen). As an introduction we will be investigating a pentapeptide, a chain of 5 amino acid residues including a combination of Tryptophan, Leucine, and Alanine. Proteins are generally a lot larger (50+ amino acids), so this gives a good proof of concept. All calculations below can be easily replicated on a larger protein provided sufficient computational power is available.

Amino acid structures from Wiki

Data abstraction

For a molecule with N atoms (nodes), its configuration at each time snapshot can be represented as a vector of length 3N, i.e. one value for each of the x, y, and z coordinates of every atom. We have access to time series from molecular dynamics (MD) simulation of the form:

$x_1, x_2, \ldots, x_T, \qquad x_t \in \mathbb{R}^{3N},$

where T is the number of timesteps. These simulations record the atomic positions over time. Oftentimes MD simulations are performed numerous times and then aggregated. Since no two MD simulations will ever be the same, due to the inherent randomness in how atoms and molecules move over time, it is valid to chain independent trajectories together like this to obtain more samples of atomic positions.

Problem Statement

We are interested in a latent space embedding map from the atomic coordinates to some lower dimensional space. Instead of using standard principal component analysis (PCA)-like dimensionality reduction on the location time-series data, we want the learned encoder to be useful for studying the kinetics of molecules. To do this, we consider using time-lagged data. For time-series data with lag τ, we want the encoder E to be such that there exists a good decoder D where:

$D(E(x_t)) \approx x_{t+\tau}$   Eq (1)

We will use time-lagged independent component analysis (TICA) to accomplish this task. Here the lag time τ is a parameter that can be fine-tuned, but it is generally flexible as long as it is not too short or too long.

Note: it is common, as seen in MSMBuilder and PyEMMA (two Python packages that use MD simulation data) to use TICA as a dimensionality reduction method on time series (i.e. trajectory) data.

So… What is time-lagged independent component analysis (TICA)?

There is no essential difficulty in explaining TICA once we take a deeper look at the formulation in Eq (1) above. It is based on the idea that the underlying dynamics of the system can be captured by a small number of independent components. One of the key advantages of TICA is its ability to identify slow collective motions that are important for understanding the behavior of the system.

In particular, if we let the encoder (a map from 3N → d, where d is a small dimension) and the decoder (a map from d → 3N) both be linear maps, then the prediction task amounts to minimizing Eq (1) over all of the time trajectories. Moreover, the TICA formulation assumes that we are interested in a reversible Markov chain, which is why we also want:

$D(E(x_{t+\tau})) \approx x_t$   Eq (2)

Combining everything together, this amounts to finding matrices D and E which minimize the following optimization program:

$\min_{D,\,E} \; \sum_{t=1}^{T-\tau} \left( \lVert x_{t+\tau} - D E\, x_t \rVert^2 + \lVert x_t - D E\, x_{t+\tau} \rVert^2 \right)$   Eq (3)

One can see that this is nothing but a linear regression task, and therefore finding D, E can be easily implemented.

TICA works by first estimating the covariance matrix of the input data, which represents the statistical relationships between the variables at different time lags. Next, it applies singular value decomposition (SVD) to this covariance matrix to extract the dominant eigenvectors, which are the independent components that capture the most variance in the data. Finally, TICA projects the original data onto these independent components to obtain a lower-dimensional representation of the system dynamics.

We implement TICA in three steps: [a] centering the data, [b] whitening the data, and [c] a singular value decomposition of the processed time-lagged data. Details can be seen in the code in our Colab implementation.
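
To make these three steps concrete, here is a minimal NumPy sketch of linear TICA. It is an illustrative helper written for this post (the function name tica_components and its defaults are ours; the Colab implementation differs in details such as regularization):

import numpy as np

def tica_components(X, lag=10, d=4):
    # X: (T, 3N) array of flattened coordinates, lag: time lag tau, d: output dimension
    # [a] center the data
    X = X - X.mean(axis=0)
    # [b] whiten using the instantaneous covariance C0
    C0 = X.T @ X / (len(X) - 1)
    evals, evecs = np.linalg.eigh(C0)
    keep = evals > 1e-10                        # drop near-singular directions
    W = evecs[:, keep] / np.sqrt(evals[keep])   # whitening transform
    Xw = X @ W
    # [c] SVD of the symmetrized time-lagged covariance of the whitened data
    Ctau = Xw[:-lag].T @ Xw[lag:] / (len(X) - lag - 1)
    U, S, Vt = np.linalg.svd(0.5 * (Ctau + Ctau.T))
    return X @ (W @ U[:, :d])                   # project the data onto the top-d components

For example, tica_components(coords.reshape(T, -1), lag=10, d=4) maps a (T, 3N) trajectory to a (T, 4) TICA representation.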

Extension to Deep Autoencoders: The use of autoencoders in combination with TICA is also straightforward. Let a function f mapping from 3N → k be any generic map which takes protein location data to k-dimensional features. Then, one can use TICA on these transformed time-series data!

Additional detail on TICA: Explicitly, TICA calculates the time-lagged covariance matrices. One then uses these to solve a generalized eigenvector problem, and the top-d eigenvectors are used as linear features of the time snapshots. More details can be found here. TICA is able to keep relevant information about slower processes, whereas another common dimension reduction technique such as principal component analysis (PCA) does not. This is why (as mentioned above), when constructing Markov state models from MD simulation data, it is common to use TICA over PCA to preserve the slow dynamics.

Key insight: Deep autoencoder with GNN equivariant layer

In particular, it is important that we use equivariant node feature maps. This ensures that if the data is rotated, the features of the rotated data are the same as the rotated features of the original data.

While any autoencoder can lead to a TICA-based data analysis, there is a specific benefit to using a rotation-equivariant featurization. For example, if our feature map is the identity I, then the TICA encoding of the time series is rotation invariant, i.e. rotating the given data by an orthogonal matrix will not change the TICA embedding. If one performs TICA on rotation-equivariant features, then one can likewise prove rotation invariance of the resultant TICA embedding. Previous papers such as this one use a simple feed-forward neural network to determine the TICA embedding, which does not satisfy equivariance. Therefore, our modification leads to a better tool for latent space embedding of protein conformations.

TICA embedding on original data (green) versus rotated data (red). The rotated system has the same TICA score up to a sign flip (due to SVD being unique up to a sign difference).

In practice, equivariant neural networks are implemented using various techniques, such as group convolutions, weight sharing, and symmetry-preserving pooling. These techniques ensure that the network parameters are equivariant to transformations (up to their dimension) and that the network output remains invariant to the choice of reference frame.
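
As a quick numerical sanity check of the rotation-invariance claim above, one can rotate every frame by the same random orthogonal matrix and compare the TICA projections before and after. The sketch below reuses the tica_components helper sketched earlier and uses synthetic placeholder coordinates rather than the real MD data:

import numpy as np

rng = np.random.default_rng(0)
coords = rng.standard_normal((5001, 20, 3))        # placeholder (frames, atoms, xyz) data
R, _ = np.linalg.qr(rng.standard_normal((3, 3)))   # random orthogonal 3x3 matrix

tica_orig = tica_components(coords.reshape(len(coords), -1), lag=10, d=2)
tica_rot = tica_components((coords @ R.T).reshape(len(coords), -1), lag=10, d=2)

# The two embeddings should agree component-wise up to a sign flip (the SVD sign ambiguity).
for i in range(tica_orig.shape[1]):
    sign = np.sign(tica_orig[:, i] @ tica_rot[:, i])
    print(f'component {i}: max abs difference = {np.abs(tica_orig[:, i] - sign * tica_rot[:, i]).max():.2e}')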

Dataset and preprocessing

Dataset

The MD simulation dataset for the pentapeptide we are using is hosted and publicly available on the MDShare website. The data contains two parts: one PDB file describing the pentapeptide's topology (atoms, residues, bonds), and trajectory files (xyz coordinates for each atom over time). We can use MDTraj and PyMOL to further inspect and visualize the topology.

import mdshare
import mdtraj as md

# Download the topology (PDB) file and the 25 trajectory (XTC) files
pdb = mdshare.fetch('pentapeptide-impl-solv.pdb', working_directory='data')
files = mdshare.fetch('pentapeptide-*-500ns-impl-solv.xtc', working_directory='data')

topology = md.load(pdb).topology
print(f'Each trajectory contains {topology.n_atoms} atoms and {topology.n_residues} residues.')

# Map three-letter residue names to one-letter codes
res_seq = []
res_dict = {'TRP': 'W', 'LEU': 'L', 'ALA': 'A'}
for res in topology.residues:
    res_seq.append(res_dict[res.name])
print('Residue sequence: ', res_seq)

traj_len = len(md.load_xtc(files[0], top=pdb))
print('Traj length: ', traj_len)

Running this, we can see that the pentapeptide has 94 atoms and 5 amino acid residues, with sequence Trp-Leu-Ala-Leu-Leu (printed above as the one-letter codes ['W', 'L', 'A', 'L', 'L']). Each trajectory has 5001 frames.

Pentapeptide visualization using PyMOL. Each atom has a label of what residue it belongs to. The sequence starts from right to left.

In total, this pentapeptide dataset consists of 25 trajectories, each with 5001 frames (25 x 5001 = 125025 total data points), each of which contains 3D coordinates of 94 atoms in 5 residues. As we mentioned previously, we are only concerned with the 4 backbone atoms in each amino acid (nitrogen, carbon, carbon-α, and oxygen), as the other parts (including sidechains) can move around more freely and will not contribute as meaningfully (i.e. kinetically) to the overall conformation. This gives 4 atom types x 5 residues = 20 atoms at each data point (and thus 20 atoms x 3 coordinates = 60 total values for each data point).
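
As one way to pull out just these backbone coordinates (a sketch; the Colab may index the atoms differently), MDTraj's selection language can select the nitrogen, carbon-α, carbon, and oxygen atoms directly:

# Select the 20 backbone atoms (N, CA, C, O) and flatten each frame to a length-60 vector
backbone_idx = topology.select('backbone')
traj = md.load_xtc(files[0], top=pdb)
backbone_xyz = traj.xyz[:, backbone_idx, :]          # shape (5001, 20, 3)
flat_coords = backbone_xyz.reshape(len(traj), -1)    # shape (5001, 60)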

An example of why we use only these 4 atom types can be seen in the figure below, where we look at TICA embeddings for one trajectory (i.e. just 5001 frames). On the left, we have TICA computed using all of the atoms, and on the right we have TICA computed using just the four backbone atoms defined above. When we use just the backbone, we see two heavily populated clusters, compared to one main populated area on the left; coordinates that fall into the two clustered areas could represent two conformations in the data that are consistently kinetically different.

TICA representations from trajectory 0. Left: TICA computed using all coordinates. Right: TICA computed using just the backbone atom coordinates.

When constructing the dataset, we store both the coordinates and the TICA values in a dictionary of structures, along with a name (given by the trajectory number and frame number) and the residue sequence. A plot of the first 2 dimensions of TICA over all (25 x 5001 = 125025) coordinates can be seen in the results section.
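
A rough sketch of how such a dictionary of structures could be assembled, building on the variables defined above (the field names are illustrative, and here TICA is computed per trajectory for brevity rather than the expanded TICA over all 25 trajectories used in the post):

structures = []
for traj_id, f in enumerate(files):
    traj = md.load_xtc(f, top=pdb)
    xyz = traj.xyz[:, backbone_idx, :]                       # (5001, 20, 3) backbone coordinates
    tica_vals = tica_components(xyz.reshape(len(traj), -1), lag=10, d=4)
    for frame in range(len(traj)):
        structures.append({
            'name': f'traj{traj_id}_frame{frame}',           # trajectory number and frame number
            'seq': res_seq,                                   # residue sequence
            'coords': xyz[frame],                             # (20, 3) coordinates
            'tica': tica_vals[frame],                         # 4-dimensional TICA values
        })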

Graph Vector Network

A lot of this is inspired heavily by the Equivariant GVP-GNNs described here:

Geometric Vector Perceptron (GVP) GNNs are networks that can utilize atomic coordinate data for problems like protein design and protein model assessment. We extend them to work on data from MD simulation to predict conformational ranges of proteins. We further modified the GVP-GNN structure to look only at vector features and the corresponding layers (rather than a combination of vector and scalar featurizations and layers). We do this because the relevant information about the protein's structure can be captured using only the features coming from the coordinate (vector) data, and because it gives a simpler, more understandable graph network for learning purposes.

Node/edge featurization: The node features taken from the structures are coordinates centered at the carbon-α atom, consisting of the orientation of this atom as well as the distances to the other backbone atoms (nitrogen, carbon, and oxygen). We also use edge features as an input to some layers (although we do not update these values through training), and featurize them using distances between neighboring nodes found through torch_cluster.knn_graph, which builds graph edges to the nearest k points.
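
A minimal sketch of the edge construction with torch_cluster (the value of k and the way the carbon-α atoms are indexed here are assumptions made for illustration):

import torch
from torch_cluster import knn_graph

# One residue-level node per carbon-alpha atom (assuming backbone atom order N, CA, C, O)
ca_coords = torch.as_tensor(structures[0]['coords'][1::4])
edge_index = knn_graph(ca_coords, k=3)               # (2, num_edges) k-nearest-neighbor edges
# Edge features: displacement vectors between connected nodes (held fixed during training)
edge_vec = ca_coords[edge_index[0]] - ca_coords[edge_index[1]]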

Data loading: We use PyG's DenseDataLoader to form batches the way we want (stacking each attribute along a new dimension). This lets us keep different types of features separate, since we want to include additional information such as the TICA values for each data point. DataLoaders are useful because they handle batching automatically.
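
A minimal sketch of this batching step (attribute names are illustrative):

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader

# Every graph has the same number of nodes, so DenseDataLoader can stack each attribute
# along a new leading batch dimension instead of concatenating graphs.
data_list = [
    Data(x=torch.as_tensor(s['coords']), tica=torch.as_tensor(s['tica']))
    for s in structures
]
loader = DenseDataLoader(data_list, batch_size=512, shuffle=True)
batch = next(iter(loader))   # batch.x: (512, 20, 3), batch.tica: (512, 4)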

GraphVector — GraphConv — GraphConvLayer — GraphEncode

We define our general GraphVector layers as several nn.Linear layers followed by a torch.sigmoid activation.

class GraphVector(nn.Module):
    '''
    Adapted from GVP from B Jing, S Eismann, et al.
    '''
    ...
    # Initialization omitted for brevity
    ...
    def forward(self, x):
        x = torch.transpose(x, -1, -2)
        xh = self.Wh(x)  # Linear layer
        x = self.Wv(xh)  # Linear layer
        x = torch.transpose(x, -1, -2)
        if self.activation:
            x = x * self.activation(_norm_no_nan(x, axis=-1, keepdims=True))  # sigmoid activation
        return x

We then stack several of these layers sequentially to form a graph convolution, which we split (following the GVP-GNN style) into GraphConv (which computes the aggregated messages by concatenating attributes from source nodes, edge attributes, and target nodes) and GraphConvLayer (which computes the residual updates and feed-forward layers).

class GraphConv(MessagePassing):
    '''
    Adapted from GVPConv from B Jing, S Eismann, et al.
    '''
    ...
    # Initialization omitted for brevity
    ...
    def forward(self, x, edge_index, edge_attr):
        out = self.propagate(edge_index, x=x.reshape(x.shape[0], 3 * x.shape[1]), edge_attr=edge_attr)
        out = out.view(out.shape[0], out.shape[-1] // 3, 3)
        return out

    def message(self, x_i, x_j, edge_attr):
        x_j = x_j.view(x_j.shape[0], x_j.shape[1] // 3, 3)
        x_i = x_i.view(x_i.shape[0], x_i.shape[1] // 3, 3)
        out = torch.concatenate((x_j, edge_attr, x_i), dim=1)
        out = self.conv_layers(out)  # conv_layers is a sequence of GraphVector layers
        out = torch.reshape(out, out.shape[:-2] + (3 * out.shape[-2],))
        return out

class GraphConvLayer(torch.nn.Module):
    '''
    Adapted from GVPConvLayer from B Jing, S Eismann, et al.
    '''
    ...
    def forward(self, x, edge_index, edge_attr):
        out = self.conv(x, edge_index, edge_attr)  # conv is a GraphConv layer
        x = self.norm[0](x + self.dropout[0](out))  # LayerNorm and Dropout

        out = self.forward_layers(x)  # sequence of GraphVector layers
        x = self.norm[1](x + self.dropout[1](out))  # LayerNorm and Dropout
        return x

As you can see, this is where we define our MessagePassing layer, which is how a GNN layer is specified in PyG. To use a MessagePassing layer, you need to define [1] the message that will be propagated along edges between nodes and [2] the way to aggregate the messages arriving at each node from its neighbors. We use a standard mean aggregation (passed in to the constructor), and our messages consist of source/target node attributes concatenated with the edge attributes. It is crucial that this aggregation step allows the network to remain invariant (or equivariant), which mean aggregation does.
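
To make the PyG pattern explicit, here is a stripped-down toy MessagePassing layer with mean aggregation (purely illustrative; it is not part of our model):

from torch_geometric.nn import MessagePassing

class MeanPass(MessagePassing):
    def __init__(self):
        super().__init__(aggr='mean')            # [2] average the messages arriving at each node
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)   # run message passing over the edges
    def message(self, x_j):
        return x_j                               # [1] the message is simply the neighbor's features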

Lastly, we have an encoder GraphEncode which transforms node and edge embeddings into a hidden dimension using GraphVector layers, and then encodes them using a series of GraphConvLayer layers. It additionally squeezes these embeddings into a small latent space, which we train to mimic the TICA representation.

class GraphEncode(torch.nn.Module):
    '''
    Encoder that uses GraphConv as encoder layers
    Takes in protein structure graphs
    '''
    ...
    # Initialization omitted for brevity
    ...

    def forward(self, node_embs, edge_index, edge_embs):
        node_embs = node_embs.reshape(node_embs.shape[0] * node_embs.shape[1],
                                      node_embs.shape[2], node_embs.shape[3])

        edge_embs = edge_embs.reshape(edge_embs.shape[0] * edge_embs.shape[1],
                                      edge_embs.shape[2], edge_embs.shape[3])

        node_embs = self.W_node(node_embs)  # GraphVector layer and LayerNorm
        edge_embs = self.W_edge(edge_embs)  # GraphVector layer and LayerNorm

        for layer in self.encode_layers:  # Sequence of GraphConvLayer layers
            node_embs = layer(node_embs, edge_index, edge_embs)

        encoded = node_embs.reshape(node_embs.shape[0] // self.node_num, self.node_num, -1, 3)
        flat_node_embs = node_embs.reshape(node_embs.shape[0] // self.node_num, -1)
        flat_node_embs = self.squeeze_layer(flat_node_embs)  # Sequence of Linear, ReLU, Dropout, Linear layers

        return encoded, flat_node_embs

Loss: As a reminder, we want our latent space to represent information similar to what TICA captures. One way to think about this is that if two frames are close in TICA space, then we want them to also be close in the latent space. We do this by minimizing the difference between the pairwise cosine similarity scores of the two.

...
tica = batch.tica
emb, latent = forward(nodes, edge_index, edges)  # GraphEncode forward pass, returning (encoded, latent)
loss = ((pairwise_cos_sim(tica) - pairwise_cos_sim(latent))**2).mean()
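
Here, pairwise_cos_sim could be implemented roughly as follows (a sketch; the exact helper in the Colab may differ):

import torch.nn.functional as F

def pairwise_cos_sim(x):
    # Flatten each sample, normalize rows to unit length, and take all pairwise
    # dot products, giving a (batch, batch) matrix of cosine similarities.
    x = F.normalize(x.reshape(x.shape[0], -1), dim=-1)
    return x @ x.T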

Results

We trained for 200 epochs, using a batch size of 512.

Left: First 2 dimensions of expanded TICA representation over all 25x5001 coordinates. Right: Latent space representation of test values on trained model

After training, we can see that the first two dimensions of our latent space do not exactly replicate the corresponding TICA dimensions. However, since we trained on the pairwise cosine similarity, this is acceptable. As seen in the code, the latent representations and the TICA representations have different dimensionality. So while several clusters are visible in the first two dimensions of the TICA representation, it is not necessarily wrong that we cannot see the same in only the first two dimensions of the latent space, since the latent space has more dimensions in which to be expressive.

One way to verify the results is to pick a sample whose latent value lies in a high-density area, and then check whether that sample also lies in a high-density area of the TICA maps. Looking at the figure below, this works as we expected:

Left: Latent space representation. Red dot is a sample (trajectory=17, frame=1050) pair in a high density region. Center: TICA representation using all trajectories and the highlighted sample is still in a high density region. Right: TICA representation for trajectory 17 where frame 1050 is highlighted, also in a high density region.

We note that all TICA and latent space representation images are just showing 2 dimensions, but they are 4- and 16-dimensional, respectively. These are parameters that we chose, and picking the “best” values for this can be studied further. We did not want to make the latent space representation too “small” because it could lose information about the underlying graph structure.

Since the example point coincides with a high-density region in all three of these embeddings, it makes sense that the latent representation does in fact contain information about the polypeptide's conformational landscape. Not only that, but it also uses information from the coordinate positions in the node features to develop similarities, indicating that its representation could be more expressive than the original TICA representation.

Conclusion

Future work can be done applying this latent representation to other, larger molecules, such as full proteins. We can also hope to use a method like this to sample example coordinates for a specific conformation identified in the latent representation. For example, if a latent representation has two high-density clusters, samples can be taken from either cluster, and the corresponding coordinates can be used as two different conformations of the protein. This is just one way to explore this representation as a conformational range.

In this blog post, we showed how one can use a graph neural network (GNN) to embed time series of atomic location data into a latent space. We demonstrated that GNNs can be used with atomic coordinates and a featurization technique that leads to a rotation-invariant embedding. There is a lot of excitement around the use of GNNs for scientific questions, and we are excited to add our contribution to it. We hope you enjoyed this blog post!
