BERT2Mult: Predicting “a priori” protein-protein interactions with graph neural networks and language models

By Julian Perez, Alejandro Lozano, and Arjun Rajan as part of the Stanford CS224W course project.

Introduction

Proteins are essential for the functioning of all biological life, performing a wide range of tasks within cells, from enzymes that catalyze biochemical reactions to antibodies that recognize pathogen-infected cells. Protein-protein interactions (PPIs) play a crucial role in many of these processes, as proteins must bind and interact with other molecules to achieve their purpose. However, predicting PPIs is a biologically complex problem that remains difficult and laborious to test exhaustively in laboratory settings, and computational approaches are still in development. Recent advances in machine learning have enabled researchers to make significant strides in predicting protein folding, but accurately predicting PPIs remains a highly sought-after goal.

One reason to develop models that accurately predict PPIs involves leveraging previously validated interactions from well-studied organisms to predict interactions in newly sequenced organisms. Predicting PPIs a priori is especially important as researchers seek to understand previously uncharacterized organisms and the diversity of life more broadly. In this blog post, we describe a novel methodology for predicting PPIs in organisms using limited or no known PPIs from the same organism. We combine the power of Graph Neural Networks (GNN) with Large Language Models (LLM) to create accurate PPI predictions.

We tested two separate GNN architectures, GraphSAGE [1] and DistMult [2], and found that, in this application, DistMult provides higher accuracy predictions while also being able to make predictions a priori. Furthermore, we found that using LLMs in conjunction with GNNs improves the accuracy of PPI predictions compared to LLMs alone. We believe that our work represents a meaningful advancement in the way PPIs are predicted and hope that it will contribute to future developments in this important field.

Our processed dataset and trained models are available on GitHub, and you can try it for yourself with our Colab notebook.

Datasets & Task Description

For this task, we will use a subset of the STRING dataset [3], which contains precomputed functional links between proteins. The STRING dataset provides a graphical representation of the network of inferred, weighted PPIs (a high-level view of functional linkage), facilitating the analysis of modularity in biological processes. It is estimated that the database predicts functional interactions at an average accuracy of at least 80% for more than half of the genes. Furthermore, STRING is updated continuously and currently contains 261,033 orthologs across 89 fully sequenced genomes. We specifically restrict ourselves to STRING data derived from biologically validated experiments, to ensure that we train and evaluate on real PPIs rather than predicted ones.

A small subset of a protein-protein interaction network visualized from the STRING database, where color saturation represents confidence. Image source: https://en.wikipedia.org/wiki/STRING

From the STRING dataset, we can subset our training and testing datasets such that the PPIs come from separate organisms. For this application, we used PPIs from 65 distinct bacterial species for training and 4 distinct bacterial species for testing. We can examine a subset of a single organism in our data using the code below.

import torch
from torch_geometric.data import Data

edge_index = torch.tensor(train_dataset[0][1]).T  # The edge index defines the PPIs in each graph, size: 2 x |E|
X = train_dataset[0][0]                           # X is the node feature matrix, size: |V_j| x d

# Let's get the cardinality of the node and edge sets (used for some statistics)
V_j = X.shape[0]
E_j = edge_index.shape[1]

data_ = Data(x=X, edge_index=edge_index)  # We create a PyG Data object with our data
print(data_)                              # Let's visualize this Data object
print(f"This organism has {V_j} proteins and {E_j} PPIs "
      f"({2 * 100 * E_j / (V_j * V_j):.4f} % of possible interactions)")

This shows that for this particular organism, the dataset contains 3440 proteins and 86904 PPIs, meaning that only about 0.0024% of all possible protein pairs are annotated as interacting. We can further visualize a subgraph of this organism’s data using PyG’s interoperability with NetworkX, to see what a PPI graph looks like.

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch_geometric

Nodes_to_plot = 30                                                # Number of nodes to plot
G = torch_geometric.utils.to_networkx(data_, to_undirected=True)  # Convert the PyG Data object to NetworkX
random_nodes = np.random.randint(low=1, high=V_j, size=Nodes_to_plot, dtype=int).tolist()       # Select random nodes
random_ppis = np.random.randint(low=1, high=E_j, size=Nodes_to_plot // 10, dtype=int).tolist()  # Select random PPIs (edges)
ppi_nodes = edge_index[:, random_ppis].unique().tolist()          # Nodes participating in the selected PPIs
random_nodes.extend(ppi_nodes)                                    # Append the PPI nodes to the random nodes
H = G.subgraph(random_nodes)                                      # Get the node-induced subgraph

## Define a color map to easily distinguish proteins with a plotted PPI
color_map = []
for node in H:
    if node in ppi_nodes:
        color_map.append('blue')
    else:
        color_map.append('green')

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
pos = nx.circular_layout(H, scale=1, center=None, dim=2)
nx.draw(H, node_color=color_map, with_labels=False, pos=pos, ax=ax[0])
ax[0].set_title("Circular Layout")

pos = nx.random_layout(H)
nx.draw(H, node_color=color_map, with_labels=False, pos=pos, ax=ax[1])
ax[1].set_title("Random Layout")
plt.show()
PPI subgraph from a single organism where blue nodes indicate proteins with characterized PPIs.

One aspect that is still missing from our data is negative interactions, that is, edges that represent pairs of proteins that do not interact with each other. These are essential for training; otherwise, our model would simply learn to predict a positive edge between every protein pair. We can use PyG’s negative_sampling to add these negative edges. For example:

from torch_geometric.utils import negative_sampling

positive_ppi = torch.tensor([[3, 0, 0, 3],
                             [2, 3, 2, 1]])

negative_ppi = negative_sampling(edge_index=positive_ppi,
                                 num_nodes=10,
                                 num_neg_samples=6,
                                 method='dense',
                                 force_undirected=True)

print("positive edges:")
print(positive_ppi)

print()

print("negative edges:")
print(negative_ppi)

Here we generate negative edges on a small toy set of nodes, and the sampler ensures that the negative edges it generates do not overlap with the positive edges. With this, we now have a representation of what the graph of PPIs should look like, which means we can train a machine-learning model.
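As a quick sanity check, a few lines like the following (our own illustrative snippet, not part of the training pipeline) confirm that the sampled negative pairs never coincide with known positive pairs:

# Treat edges as undirected and verify the two sets are disjoint
positive_set = {tuple(sorted(pair)) for pair in positive_ppi.T.tolist()}
negative_set = {tuple(sorted(pair)) for pair in negative_ppi.T.tolist()}
assert positive_set.isdisjoint(negative_set), "Negative sampling produced a known PPI!"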

Our task is to predict PPIs in species where they are currently unknown. More specifically, we want to predict a priori which proteins interact with which proteins, without any knowledge of the protein interactions within the species. In this case, the only information we can rely on is the protein sequences: the chains of amino acids (drawn from an alphabet of 20) that form the polypeptide making up each protein. For each species, we have a set of strings representing the protein sequences within that species, and these are represented as nodes. Then, we need to predict edges between these nodes, where the edges represent PPIs. With this in mind, our first subtask is to gain some additional information about the proteins we are working with.

Using Large Language Models to embed proteins in low-dimensional space

When using GNNs to predict edges, we first need to embed the nodes of our graph into a feature space. Each node should have features that contain some information about the node (in this case, information about the protein). Then, the GNN can learn which nodes should be connected by edges. But how can we get feature information for our nodes if all we have is a string representing their amino acid sequence?

The ProteinBERT architecture. ProteinBERT was trained on ~106 million well-documented proteins (UniRef90). We can use the pre-trained model to feed in input sequences and get out annotations (embeddings). Image source: https://doi.org/10.1093/bioinformatics/btac020

To solve this, we used ProteinBERT [4], a deep language model specifically designed for proteins. Its pre-training scheme combines language modeling with a novel Gene Ontology (GO) annotation prediction task, and its architecture includes elements that make the model highly efficient and flexible with respect to long sequences. Given a protein sequence alone, ProteinBERT produces a 1024-dimensional feature vector, encoding every protein into a fixed-size, low-dimensional embedding space; we need no additional information about the protein or the species it comes from. Thus, given the protein sequences, we can use the pre-trained ProteinBERT model to embed the nodes in our graph for training.
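As a rough illustration, embedding a batch of sequences with the proteinbert package looks something like the snippet below. This is a sketch reconstructed from the package’s documented usage, so treat the function names and signatures (load_pretrained_model, encode_X, get_model_with_hidden_layers_as_outputs) as assumptions that may differ between versions.

# Sketch only: API names follow the proteinbert package's documented usage and may vary by version.
from proteinbert import load_pretrained_model
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

seqs = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",   # Example amino-acid sequences (illustrative only)
        "MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIEF"]
seq_len = 512                                   # Sequences are padded/truncated to a fixed length

pretrained_model_generator, input_encoder = load_pretrained_model()
model = get_model_with_hidden_layers_as_outputs(pretrained_model_generator.create_model(seq_len))

encoded_x = input_encoder.encode_X(seqs, seq_len)
local_representations, global_representations = model.predict(encoded_x, batch_size=2)

# global_representations holds one fixed-size vector per protein; these become our node features X
print(global_representations.shape)

With our embeddings, we are now ready to start training our GNN to predict PPIs.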

BERT2Sage: Predicting PPIs by applying GraphSAGE to ProteinBERT embeddings

Before diving into the specific GNN we chose, we should first describe the general principles of GNNs. In a GNN, the idea is to develop an embedding for each node based on its local neighborhood. To improve the embedding of a target node, we AGGREGATE information from the embeddings of the target node’s neighbors across K layers (where K is the number of hops away from the target node we aggregate over). Each GNN layer is therefore a two-step process: first, message computation, where the layer-(k−1) embeddings of the target node’s neighbors are passed through a message function; and second, aggregation, where the messages from those neighbors are combined. The result of this aggregation is then used to update the embedding of the target node. Overall, this process lets the model learn improved embeddings for nodes based on the properties of the surrounding nodes.
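In symbols, a single message-passing layer k can be summarized as follows, where h_v^{(k)} is the embedding of node v after layer k and N(v) is its set of neighbors (standard GNN notation, not specific to our code):

m_v^{(k)} = \mathrm{AGGREGATE}^{(k)}\Big(\big\{\, \mathrm{MSG}^{(k)}\big(h_u^{(k-1)}\big) : u \in N(v) \,\big\}\Big), \qquad h_v^{(k)} = \mathrm{UPDATE}^{(k)}\Big(h_v^{(k-1)},\; m_v^{(k)}\Big)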

For our initial GNN, we chose to implement GraphSAGE [1]. GraphSAGE learns a representation for each node by combining information from its neighboring nodes. In our case, we use a single layer of GraphSAGE convolution (K=1) for message passing, because PPIs are represented by nodes that are one hop away from each other, and additional layers could lead to over-smoothing. Additionally, the graph of PPIs is quite sparse, with fewer than 1% of possible protein pairs in a given graph interacting, so additional layers would provide minimal benefit. We chose GraphSAGE because it is an inductive framework, meaning it leverages node attribute information to efficiently generate representations for previously unseen data.
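For reference, PyG’s SAGEConv layer (which we use below) implements the mean-aggregator form of this update:

h_v^{(k)} = W_1^{(k)} h_v^{(k-1)} + W_2^{(k)} \cdot \operatorname{mean}_{u \in N(v)} h_u^{(k-1)}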

We can use PyG to implement GraphSAGE:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg


class GNNStack(torch.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, layers: int,
                 dropout: float = 0.3, return_embedding: bool = False):
        """
        A stack of GraphSAGE layers
        input_dim <int>: Input dimension
        hidden_dim <int>: Hidden dimension
        output_dim <int>: Output dimension
        layers <int>: Number of layers
        dropout <float>: Dropout rate
        return_embedding <bool>: Whether to return the embedding of the input graph
        """
        super(GNNStack, self).__init__()
        graphSage_conv = pyg.nn.SAGEConv
        self.dropout = dropout
        self.layers = layers
        self.return_embedding = return_embedding

        ### Initialize the layers ###
        self.convs = nn.ModuleList()  # ModuleList to hold the layers
        for l in range(self.layers):
            if l == 0:
                ### First layer maps from input_dim to hidden_dim ###
                self.convs.append(graphSage_conv(input_dim, hidden_dim))
            else:
                ### All other layers map from hidden_dim to hidden_dim ###
                self.convs.append(graphSage_conv(hidden_dim, hidden_dim))

        # Post-message-passing processing MLP
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        for i in range(self.layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)

        # Return the final layer of embeddings if specified
        if self.return_embedding:
            return x

        # Else return class probabilities
        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)


### We will use this function to save our best model during training ###
def save_torch_model(model, epoch, PATH: str, optimizer):
    print(f"Saving model in path {PATH}")
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer': optimizer,
                }, PATH)

One big question remains: how can we use the embeddings we have worked hard to generate, through ProteinBERT followed by GraphSAGE, to predict PPIs? In our case, we use a multilayer perceptron (MLP) that takes the embeddings of two nodes of a given graph, z_i and z_j, computes their element-wise product, and maps the result to a probability that the two proteins interact. We chose element-wise multiplication because it is commutative, meaning the order of the two input proteins does not matter.
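A link predictor along these lines might look like the sketch below. This is a minimal, illustrative version (the class name LinkPredictor and the layer sizes are our own choices here, not necessarily the exact architecture we trained):

class LinkPredictor(nn.Module):
    """Maps a pair of node embeddings (z_i, z_j) to an interaction probability."""

    def __init__(self, embedding_dim: int, hidden_dim: int, dropout: float = 0.3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1))

    def forward(self, z_i, z_j):
        # The element-wise product is commutative, so the order of the two proteins does not matter
        return torch.sigmoid(self.mlp(z_i * z_j)).squeeze(-1)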

We initialize our GraphSAGE model and MLP on the embeddings of our proteins that we receive from ProteinBERT. Then, in each iteration of the training loop, we take the current embeddings of our graph, update them by passing them through our GraphSAGE model, predict edges using our MLP, compute the loss against the known PPIs we originally obtained from STRING, and finally update both the GraphSAGE and MLP models. By following this loop, we can measure how accurate our PPI predictions are.
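A condensed sketch of this loop is shown below. It assumes the GNNStack defined above, the hypothetical LinkPredictor from the previous snippet, a single organism’s node features data_.x, and a tensor train_edge_index of known PPIs; num_epochs, the learning rate, and the hidden sizes are illustrative, and for brevity we ignore the split between message-passing edges and supervision edges used in practice.

gnn = GNNStack(input_dim=1024, hidden_dim=256, output_dim=256, layers=1, return_embedding=True)
predictor = LinkPredictor(embedding_dim=256, hidden_dim=128)
optimizer = torch.optim.Adam(list(gnn.parameters()) + list(predictor.parameters()), lr=1e-3)

for epoch in range(num_epochs):
    gnn.train(); predictor.train()
    optimizer.zero_grad()

    # 1) One round of message passing over the (partially) known PPI graph
    z = gnn(data_.x, train_edge_index)

    # 2) Score known (positive) pairs and freshly sampled negative pairs
    neg_edge_index = negative_sampling(train_edge_index, num_nodes=data_.num_nodes,
                                       num_neg_samples=train_edge_index.size(1))
    pos_pred = predictor(z[train_edge_index[0]], z[train_edge_index[1]])
    neg_pred = predictor(z[neg_edge_index[0]], z[neg_edge_index[1]])

    # 3) Binary cross-entropy against the labels derived from STRING
    pred = torch.cat([pos_pred, neg_pred])
    label = torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)])
    loss = F.binary_cross_entropy(pred, label)

    # 4) Update the GraphSAGE and MLP parameters
    loss.backward()
    optimizer.step()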

BERT2Sage Results

To train our BERT2Sage model, we tested it with various configurations of starting percentages of known PPIs, because GraphSAGE requires at least a partial network to make predictions. All models were trained for 500 epochs. Even at 500 epochs, with 60% of the known PPIs given at the start, the model was still improving, suggesting that we could see further gains by training for more epochs.

Train loss and accuracy curves for BERT2Sage with 60% of known PPIs given at the start.

As can be seen below, the model improved significantly when it started with more known PPIs. That said, even with just 10% of PPIs known, the model achieved strong performance, with a sensitivity of 85.9% and a specificity of 90.84%. This demonstrates that, with only a little knowledge of existing PPIs in the organism being studied, GNNs paired with LLMs can provide very accurate PPI predictions.

BERT2Sage output with 1% of starting PPIs known.
BERT2Sage output with 10% of starting PPIs known.
BERT2Sage output with 50% of starting PPIs known.

Caveats of the BERT2Sage model

There are two potential issues with the BERT2Sage model. The first is that ProteinBERT alone might already provide embeddings robust enough for our MLP to predict PPIs accurately, without the need for GraphSAGE. To address this, we trained an MLP directly on the ProteinBERT embeddings and found that such a model is unable to learn PPIs.

MLP trained on ProteinBERT embeddings is unable to accurately predict PPIs.

This demonstrates that GraphSAGE, and GNNs more generally, provide a valuable step in predicting PPIs by refining the embeddings produced by an LLM into better predictions.

Another potential issue is that BERT2Sage requires some knowledge of the known PPIs since GraphSAGE needs some initial information to make predictions. This is in direct conflict with our original task, which was to predict PPIs a priori. While BERT2Sage does show impressive prediction capability with just 10% of known PPIs, we endeavored to design an improved model that could work completely a priori. This led to the development of BERT2Mult.

BERT2Mult: Predicting PPIs by applying DistMult to ProteinBERT embeddings

DistMult is a GNN based on a knowledge graph, in which each relation, like each entity embedding, is represented as a low-dimensional vector. In this sense, DistMult does not need any initial adjacency matrix, as it learns how to represent entities and relations in a continuous vector space based on their occurrences in the training graphs. DistMult takes as input a triple (h, r, t), where h is the node representing the head entity, t is the node representing the tail entity, and r is the relation. DistMult learns to predict missing triples by modeling the likelihood of observed triples. We can therefore train DistMult using our protein embeddings from ProteinBERT and our training dataset of 65 organisms from STRING, without requiring any known PPIs for the 4 test organisms, since ideally DistMult will have learned the proper relation r. This framework is thus able to accomplish something essential to our task that GraphSAGE could not: predicting PPIs a priori.
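At its core, DistMult scores a candidate triple with a bilinear product of the head embedding, a diagonal relation matrix, and the tail embedding. In our setting there is only one relation ("interacts with"), and, as in our implementation below, a sigmoid turns the raw score into an interaction probability (standard DistMult notation, not code from our repository):

\mathrm{score}(h, r, t) = \sigma\!\Big( \mathbf{e}_h^{\top} \, \mathrm{diag}(\mathbf{r}) \, \mathbf{e}_t \Big) = \sigma\!\Big( \textstyle\sum_{k} r_k \, e_{h,k} \, e_{t,k} \Big)

We can implement this in PyTorch with the following code snippet: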

import torch
import torch.nn as nn
import torch.nn.functional as F


class NSDistMultCE(nn.Module):
    def __init__(self, in_dim=1024, embedding_dim=200, c_neg=None, regularization=0.002, device='cuda'):
        super(NSDistMultCE, self).__init__()
        self.in_dim = in_dim                  # Dimension of the BERT embedding
        self.embedding_dim = embedding_dim    # Lower dimension used by DistMult
        self.c_neg = c_neg                    # Proportion of negative cases relative to positive; if None it is calculated for each organism
        self.regularization = regularization  # Regularization parameter used to keep the L2 norm of the parameters from exploding
        self.reduction = torch.sum            # Reduction used to calculate the total loss
        self.device = device                  # Device used to hold data and perform calculations

        self.linear_emb = nn.Linear(in_dim, embedding_dim, bias=False, device=self.device)  # Linear map from the original embedding to the lower dimension
        self.tanh = nn.Tanh()                 # Non-linearity used to project the original embeddings
        self.linear_score = nn.Linear(embedding_dim, 1, bias=False, device=self.device)     # Linear transformation representing the PPI relation vector
        self.sigmoid = nn.Sigmoid()           # Non-linearity applied to scores
        self.criterion = nn.ReLU()            # Used to keep the total loss non-negative

    def reset_parameters(self):
        self.linear_emb.reset_parameters()
        self.linear_score.reset_parameters()

    def forward(self, embeddings, positive_pairs):
        # Transform embeddings to the low dimension and determine c_neg
        embeddings = embeddings.squeeze().to(self.device)
        low_dim_embeddings = self.embed(embeddings)
        c_neg = self.c_neg
        if c_neg is None:
            c_neg = len(positive_pairs[0]) / (embeddings.shape[0]**2 - len(positive_pairs[0]))

        # Calculate L_p (loss over the observed positive pairs)
        positive_scores = self.get_score(low_dim_embeddings, positive_pairs)
        L_p = self.reduction(-torch.log(positive_scores) + c_neg * torch.log(1 - positive_scores), dtype=float)
        del positive_scores

        # Calculate L_a (loss over all pairs, weighted by c_neg)
        low_dim_embeddings_r = torch.einsum('i,ij->ij', self.linear_score.weight.squeeze(), low_dim_embeddings.T)
        all_scores = self.sigmoid(torch.mm(low_dim_embeddings, low_dim_embeddings_r))
        L_a = -c_neg * self.reduction(torch.log(1 - all_scores), dtype=float)
        del low_dim_embeddings_r
        del all_scores

        return self.loss(L_p, L_a)

    def embed(self, embeddings):
        # Transform the original embedding into the lower dimension
        return self.tanh(self.linear_emb(embeddings))

    def predict(self, embeddings, pairs):
        # Get scores for the given pairs from the high-dimensional embeddings
        low_dim_embeddings = self.embed(embeddings.to(self.device))
        scores = self.get_score(low_dim_embeddings, pairs)
        return scores

    def get_score(self, embeddings, pairs):
        # Get scores for the given pairs from the low-dimensional embeddings
        heads = embeddings[pairs[0]]
        tails = embeddings[pairs[1]]
        raw_scores = self.linear_score(heads * tails)
        return self.sigmoid(raw_scores)

    def loss(self, L_p, L_a):
        # Total regularized loss, clamped at zero
        return self.criterion((L_p + L_a) + self.regularization * (torch.norm(self.linear_score.weight) + torch.norm(self.linear_emb.weight)))

In this case, we do not need a separate MLP: because DistMult trains the relation directly, it can predict the relations between proteins rather than only updating the protein embeddings. The training loop is therefore relatively straightforward.
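A condensed sketch of that loop over the training organisms is shown below, reusing the (node features, PPI list) layout of train_dataset from earlier; num_epochs and the learning rate are illustrative choices rather than our exact training configuration.

model = NSDistMultCE(in_dim=1024, embedding_dim=200, device='cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    total_loss = 0.0
    for X, ppi_list in train_dataset:                          # One (node features, PPI list) pair per organism
        embeddings = torch.as_tensor(X, dtype=torch.float32)   # ProteinBERT embeddings, |V_j| x 1024
        positive_pairs = torch.as_tensor(ppi_list).T           # Known PPIs as an edge index, 2 x |E_j|

        optimizer.zero_grad()
        loss = model(embeddings, positive_pairs)               # Forward pass returns the regularized loss directly
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch}: total loss = {total_loss:.3f}")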

BERT2Mult Results

We found that, after training for 1200 epochs, BERT2Mult achieved an accuracy of 94.5% and a specificity of 89.4%.

Results of BERT2Mult model with 1200 epochs, 65 training organisms, 4 test organisms.

This demonstrates a small improvement over BERT2Sage and has the benefit of being completely a priori. Thus, BERT2Mult demonstrates the power of combining LLMs and GNNs together in order to predict PPIs in previously uncharacterized organisms.

Conclusions and Future Work

We hope that the work demonstrated here provides inspiration to combine LLMs with GNNs to make predictions about interactions between complex objects. We found that BERT2Mult was able to accurately predict PPIs completely a priori. In the future, we hope to train BERT2Mult on a larger dataset so that it can be used as a tool by researchers to search for PPIs in newly sequenced organisms, without the need for any known information from the organism. Future work should also benchmark BERT2Mult against other state-of-the-art approaches, such as AlphaFold2-based PPI prediction [5], to better assess the advance made by BERT2Mult.

References

[1] Hamilton, W., Ying, R. & Leskovec, J. Inductive Representation Learning on Large Graphs. arXiv, (2017). https://doi.org/10.48550/arXiv.1706.02216

[2] Yang, B., Yih, W., He, X., Gao, J. & Deng, L. Embedding Entities and Relations for Learning and Inference in Knowledge Bases. arXiv, (2014).
https://doi.org/10.48550/arXiv.1412.6575

[3] Szklarczyk, D., Gable, A. L., Lyon, D., Junge, A., Wyder, S., Huerta-Cepas, J., Simonovic, M., Doncheva, N. T., Morris, J. H., Bork, P., Jensen, L. J., & von Mering, C. STRING v11: Protein–protein association networks with increased coverage, supporting functional discovery in genome-wide experimental datasets. Nucleic Acids Research, 47(D1), (2018). https://doi.org/10.1093/nar/gky1131

[4] Brandes, N., Ofer, D., Peleg, Y., Rappoport, N., & Linial, M. Proteinbert: A universal deep-learning model of protein sequence and function. Bioinformatics, 38(8), 2102–2110, (2022). https://doi.org/10.1093/bioinformatics/btac020

[5] Bryant, P., Pozzati, G. & Elofsson, A. Improved prediction of protein-protein interactions using AlphaFold2. Nat Commun 13, 1265 (2022). https://doi.org/10.1038/s41467-022-28865-w
