# BERT2Mult: Predicting “a priori” protein-protein interactions with graph neural networks and language models


*By Julian Perez, Alejandro Lozano, and Arjun Rajan as part of the Stanford CS224W course project.*

## Introduction

Proteins are essential for the functioning of all biological life, performing a wide range of tasks within cells, from enzymes that catalyze biochemical reactions to antibodies that recognize pathogen-infected cells. Protein-protein interactions (PPIs) play a crucial role in many of these processes, as proteins must bind and interact with other molecules to achieve their purpose. However, PPIs are biologically complex and remain difficult to test in laboratory settings, and computational approaches to predicting them are still in development. Recent advances in machine learning have enabled researchers to make significant strides in predicting protein folding, but accurately predicting PPIs remains a highly sought-after goal.

One reason to develop models that accurately predict PPIs is to leverage previously validated interactions from well-studied organisms to predict interactions in newly sequenced organisms. Predicting PPIs *a priori* is especially important as researchers seek to understand previously uncharacterized organisms and the diversity of life more broadly. In this blog post, we describe a novel methodology for predicting PPIs in organisms with limited or no known PPIs from the same organism. We combine the power of Graph Neural Networks (GNN) with Large Language Models (LLM) to create accurate PPI predictions.

We tested two separate GNN architectures, GraphSAGE [1] and DistMult [2], and found that, in this application, DistMult provides higher accuracy predictions while also being able to make predictions *a priori*. Furthermore, we found that using LLMs in conjunction with GNNs improves the accuracy of PPI predictions compared to LLMs alone. We believe that our work represents a meaningful advancement in the way PPIs are predicted and hope that it will contribute to future developments in this important field.

Our processed dataset and trained model are available on GitHub, and you can try them yourself with our Colab notebook.

## Datasets & Task Description

For this task, we will use a subset of the STRING dataset [3] which contains precomputed functional links between proteins. The STRING dataset provides a graphical representation of the network of inferred, weighted PPIs (high-level view of functional linkage), facilitating the analysis of modularity in biological processes. It is estimated that the database predicts functional interactions at an average accuracy of at least 80% for more than half of the genes. Furthermore, STRING is updated continuously, and currently contains 261,033 orthologs in 89 fully sequenced genomes. We specifically focus on using data from the STRING dataset that comes from biologically validated experiments, to ensure that the data we are training on and predicting is from real PPIs, rather than predicted PPIs.

From the STRING dataset, we can subset our training and testing datasets such that the PPIs come from separate organisms. For this application, we used PPIs from 65 distinct bacterial species for training and 4 distinct bacterial species for testing. We can examine a subset of a single organism in our data using the code below.

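```python
import torch
from torch_geometric.data import Data

# train_dataset comes from our notebook: each item is (node_features, edge_list) for one organism
edge_index = torch.tensor(train_dataset[0][1]).T  # the edge index defines the PPIs in every graph, size: 2 x |E|
X = train_dataset[0][0]                           # X is the node feature matrix, size: |V_j| x d

# Cardinality of the node and edge sets (used for some statistics)
V_j = X.shape[0]
E_j = edge_index.shape[1]

data_ = Data(x=X, edge_index=edge_index)  # create a PyG Data object with our data
print(data_)  # inspect this data object

# Density: percentage of all possible protein pairs that interact.
# Assumes each undirected PPI is stored in both directions (i -> j and j -> i) in edge_index,
# so the denominator counts ordered pairs of distinct proteins.
density = 100 * E_j / (V_j * (V_j - 1))
print(f"This organism has {V_j} proteins and {E_j} PPIs ({density:.4f}% of possible interactions)")
```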

This shows that this particular organism has 3440 proteins and 86904 PPIs in the dataset. That is only about 0.73% of all possible protein pairs if each interaction is stored in both directions in `edge_index` (about 1.47% if each is stored once), so the graph is very sparse. We can further visualize a subgraph of this organism’s data using PyG’s NetworkX conversion utility to see what a PPI graph looks like.

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from torch_geometric.utils import to_networkx

Nodes_to_plot = 30  # number of random nodes to plot

G = to_networkx(data_, to_undirected=True)  # convert PyG data to NetworkX

random_nodes = np.random.randint(low=0, high=V_j, size=Nodes_to_plot).tolist()       # select random nodes
random_ppis = np.random.randint(low=0, high=E_j, size=Nodes_to_plot // 10).tolist()  # select random PPIs
ppi_nodes = edge_index[:, random_ppis].unique().tolist()  # endpoints of the selected PPIs

random_nodes.extend(ppi_nodes)  # combine both node lists
H = G.subgraph(random_nodes)    # node-induced subgraph

# Define a color map to easily distinguish proteins that take part in the sampled PPIs
color_map = ['blue' if node in ppi_nodes else 'green' for node in H]

fig, ax = plt.subplots(1, 2, figsize=(10, 4))

pos = nx.circular_layout(H, scale=1, center=None, dim=2)
nx.draw(H, node_color=color_map, with_labels=False, pos=pos, ax=ax[0])
ax[0].set_title("Circular Layout")

pos = nx.random_layout(H)
nx.draw(H, node_color=color_map, with_labels=False, pos=pos, ax=ax[1])
ax[1].set_title("Random Layout")

plt.show()
```

One aspect still missing from our data is negative interactions, that is, pairs of proteins that do not interact with each other. These are essential for training; otherwise, our model will simply learn to predict an edge between every protein pair. We can use PyG’s `negative_sampling` utility to add these negative edges. For example:

```python
import torch
from torch_geometric.utils import negative_sampling

positive_ppi = torch.tensor([[3, 0, 0, 3],
                             [2, 3, 2, 1]])

negative_ppi = negative_sampling(edge_index=positive_ppi,
                                 num_nodes=10,
                                 num_neg_samples=6,
                                 method='dense',
                                 force_undirected=True)

print("positive edges:")
print(positive_ppi)
print()
print("negative edges:")
print(negative_ppi)
```

This way, we can generate negative edges on a small toy set of nodes while ensuring that they do not overlap with the positive edges. With both positive and negative edges in hand, we have a representation of what the graph of PPIs should look like, which means we can train a machine learning model.

Our task is to predict PPIs in species where they are currently unknown. More specifically, we want to predict *a priori* which proteins interact with which, without any knowledge of the protein interactions within the species. In this case, the only information we can rely on is the protein sequences: strings over the alphabet of 20 standard amino acids that describe the polypeptide chain of each protein. For each species, we have a set of strings representing its protein sequences, and each of these becomes a node. We then need to predict edges between these nodes, where the edges represent PPIs. With this in mind, our first subtask is to gain some additional information about the proteins we are working with.

## Using Large Language Models to embed proteins in low dimensional space

When using GNNs to predict edges, we first need to embed the nodes of our graph into a feature space. Each node should have features that contain some information about the node (in this case, information about the protein). Then, the GNN can learn which nodes should be connected by edges. But how can we get feature information for our nodes if all we have is a string representing their amino acid sequence?

To solve this, we used ProteinBERT [4], a deep language model designed specifically for proteins. Its pretraining scheme combines language modeling with a novel Gene Ontology (GO) annotation prediction task, and its architecture includes novel elements that make the model highly efficient and flexible with long sequences. Given a protein sequence, ProteinBERT produces a 1024-dimensional feature vector, so proteins of any length are encoded into the same fixed-size, low-dimensional embedding space. Because the model needs only the sequence, we do not require any additional information about the protein or the species it comes from. Thus, given the protein sequences, we can use the pretrained ProteinBERT model to compute the node features of our graph. With these embeddings, we are ready to start training our GNN to predict PPIs.

## BERT2Sage: Predicting PPIs by applying GraphSAGE to ProteinBERT embeddings

Before diving into the specific GNN we chose, we should first describe the general principles of GNNs. In a GNN, the idea is to compute an embedding for each node based on its local neighborhood. For a target node whose embedding we want to improve, we AGGREGATE information from the embeddings of its neighbors over K layers, where K is the number of hops away from the target node that we aggregate across. Each GNN layer is a two-step process. The first step is message computation, where we pass the layer-(k-1) embedding of each neighbor through a message function; the second step is aggregation, where we combine the messages from all of the target node’s neighbors. The result of this aggregation is then used to update the embedding of the target node. Overall, this process lets the model learn improved node embeddings from the properties of the surrounding nodes.
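In symbols, one common way to write layer k for a target node v is shown below. This is a generic formulation rather than the exact equations of any particular library; h_v^(0) is the input feature vector of node v (for us, its ProteinBERT embedding) and N(v) is the set of v’s neighbors. With K = 1, each node’s final embedding depends only on its direct neighbors.

```latex
\begin{aligned}
m_{u \to v}^{(k)} &= \mathrm{MSG}^{(k)}\!\left(h_u^{(k-1)}\right), \qquad u \in N(v) \\
h_v^{(k)} &= \mathrm{UPDATE}^{(k)}\!\left(h_v^{(k-1)},\; \mathrm{AGG}^{(k)}\!\left(\left\{ m_{u \to v}^{(k)} : u \in N(v) \right\}\right)\right)
\end{aligned}
```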

For our initial GNN, we chose to implement GraphSAGE [1]. GraphSAGE learns a representation for each node from a combination of its own features and those of its neighbors. We use a single layer of GraphSAGE convolution (K=1) for message passing, because interacting proteins are 1 hop away from each other and additional layers could lead to over-smoothing. Moreover, the PPI graph is quite sparse, with fewer than 1% of possible protein pairs interacting in a given graph, so additional layers would provide minimal benefit. We chose GraphSAGE because it is an inductive framework: it leverages node attribute information to efficiently generate representations for previously unseen data.
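To make the GraphSAGE update concrete, here is a minimal plain-PyTorch sketch of one mean-aggregation layer. It is an illustration only, not the `SAGEConv` implementation used in our code; the class name `MeanSAGELayer` is hypothetical. It computes h_v' = W_root h_v + W_neigh · mean of the neighbor embeddings h_u, which has the same form as the update rule documented for PyG’s `SAGEConv` (non-linearity omitted).

```python
import torch
import torch.nn as nn


class MeanSAGELayer(nn.Module):
    """Illustrative GraphSAGE-style layer with mean aggregation:
    h_v' = W_root h_v + W_neigh * mean_{u in N(v)} h_u
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin_root = nn.Linear(in_dim, out_dim)               # transforms the node's own features
        self.lin_neigh = nn.Linear(in_dim, out_dim, bias=False)  # transforms the neighbor average

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # each column is a message from src (neighbor) to dst (target)
        # Message + aggregation: sum neighbor features into each target node...
        summed = torch.zeros_like(x).index_add_(0, dst, x[src])
        # ...then divide by the in-degree (clamped so nodes without neighbors divide by 1).
        ones = torch.ones(dst.numel(), dtype=x.dtype, device=x.device)
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device).index_add_(0, dst, ones)
        mean_neigh = summed / deg.clamp(min=1).unsqueeze(-1)
        # Update: combine the node's own features with its neighborhood summary.
        return self.lin_root(x) + self.lin_neigh(mean_neigh)


# Usage: 4 proteins with 8-dimensional features, two interacting pairs stored in both directions
torch.manual_seed(0)
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
layer = MeanSAGELayer(in_dim=8, out_dim=16)
out = layer(x, edge_index)  # shape (4, 16)
```

Because the neighbor term has no bias, a protein with no known interactions simply keeps a linear transform of its own features.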

We can use PyG to implement GraphSAGE:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class GNNStack(torch.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, layers: int,
                 dropout: float = 0.3, return_embedding: bool = False):
        """
        A stack of GraphSAGE layers.

        input_dim <int>: Input dimension
        hidden_dim <int>: Hidden dimension
        output_dim <int>: Output dimension
        layers <int>: Number of layers
        dropout <float>: Dropout rate
        return_embedding <bool>: Whether to return the node embeddings instead of class log-probabilities
        """
        super().__init__()
        graphSage_conv = SAGEConv
        self.dropout = dropout
        self.layers = layers
        self.return_embedding = return_embedding

        ### Initialize the layers ###
        self.convs = nn.ModuleList()  # ModuleList to hold the layers
        for l in range(self.layers):
            if l == 0:
                ### First layer maps from input_dim to hidden_dim ###
                self.convs.append(graphSage_conv(input_dim, hidden_dim))
            else:
                ### All other layers map from hidden_dim to hidden_dim ###
                self.convs.append(graphSage_conv(hidden_dim, hidden_dim))

        # Post-message-passing processing MLP
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_dim, output_dim))

    def forward(self, x, edge_index):
        for i in range(self.layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.post_mp(x)

        # Return the final-layer embeddings if requested
        if self.return_embedding:
            return x
        # Otherwise return class log-probabilities
        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)


### We use this function to save our best model during training ###
def save_torch_model(model, epoch, PATH: str, optimizer):
    print(f"Saving model to {PATH}")
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),  # save the optimizer state, not the object
                }, PATH)
```

One big question remains: how can we use the embeddings we have worked hard to generate, first with ProteinBERT and then with GraphSAGE, to predict PPIs? We take the embeddings of two nodes of a given graph, z_i and z_j, compute their element-wise product, and pass it through a multilayer perceptron (MLP) that outputs the probability that the two proteins interact. We chose element-wise multiplication because it is commutative, so the order of the two input proteins does not matter.
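A minimal sketch of such a pair scorer in PyTorch is shown below. The class name `HadamardLinkPredictor` and the hidden size of 64 are illustrative assumptions, not the exact architecture we trained. The usage example checks the property that motivated the design: swapping the two proteins in a pair leaves the score unchanged.

```python
import torch
import torch.nn as nn


class HadamardLinkPredictor(nn.Module):
    """Scores protein pairs from the element-wise product of their embeddings (sketch)."""

    def __init__(self, emb_dim: int, hidden_dim: int = 64):  # hidden_dim is an assumed value
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, z: torch.Tensor, pairs: torch.Tensor) -> torch.Tensor:
        # pairs has shape (2, P): row 0 holds protein i, row 1 holds protein j
        z_i, z_j = z[pairs[0]], z[pairs[1]]
        # z_i * z_j equals z_j * z_i, so the score does not depend on pair order
        logits = self.mlp(z_i * z_j).squeeze(-1)
        return torch.sigmoid(logits)  # interaction probability for each pair


# Usage: 5 proteins with 32-dimensional embeddings, scoring pairs (0, 3) and (2, 4)
torch.manual_seed(0)
z = torch.randn(5, 32)
pairs = torch.tensor([[0, 2],
                      [3, 4]])
predictor = HadamardLinkPredictor(emb_dim=32)
p = predictor(z, pairs)
p_swapped = predictor(z, pairs.flip(0))  # same pairs with head and tail swapped
```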

We initialize the node features with the ProteinBERT embeddings of our proteins. Each training iteration then passes these embeddings through GraphSAGE to obtain updated embeddings, predicts edges with our MLP, computes the loss against the known PPIs we obtained from STRING, and updates both the GraphSAGE and MLP parameters. After training, we evaluate how accurately the model predicts held-out PPIs.
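The training loop described above can be sketched as follows. This is a toy, self-contained illustration under stated assumptions: random features stand in for ProteinBERT embeddings, a single hand-written mean-aggregation layer (the hypothetical `Encoder`) stands in for GraphSAGE, negatives are drawn uniformly at random each step (a simplification of `negative_sampling`, which also filters out positives), and for brevity the same known PPIs serve both for message passing and as training targets. In a real run, supervision edges would be kept separate from the known edges used for message passing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins (hypothetical): 20 proteins with 1024-d "ProteinBERT" features
num_nodes, in_dim, hid = 20, 1024, 32
x = torch.randn(num_nodes, in_dim)
pos = torch.randint(0, num_nodes, (2, 40))        # known PPIs (toy)
msg_edges = torch.cat([pos, pos.flip(0)], dim=1)  # undirected graph used for message passing


class Encoder(nn.Module):
    """One mean-aggregation layer: h_v = relu(W1 x_v + W2 * mean of neighbor features)."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.root = nn.Linear(in_dim, out_dim)
        self.neigh = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x, ei):
        summed = torch.zeros_like(x).index_add_(0, ei[1], x[ei[0]])
        deg = torch.bincount(ei[1], minlength=x.size(0)).clamp(min=1).unsqueeze(-1)
        return F.relu(self.root(x) + self.neigh(summed / deg))


encoder = Encoder(in_dim, hid)
predictor = nn.Sequential(nn.Linear(hid, hid), nn.ReLU(), nn.Linear(hid, 1))
opt = torch.optim.Adam(list(encoder.parameters()) + list(predictor.parameters()), lr=1e-3)


def train_step() -> float:
    encoder.train()
    predictor.train()
    opt.zero_grad()
    z = encoder(x, msg_edges)                     # 1) refine embeddings by message passing
    neg = torch.randint(0, num_nodes, pos.shape)  # 2) uniform random negative pairs
    pairs = torch.cat([pos, neg], dim=1)
    labels = torch.cat([torch.ones(pos.size(1)), torch.zeros(neg.size(1))])
    logits = predictor(z[pairs[0]] * z[pairs[1]]).squeeze(-1)  # 3) score pairs via element-wise product
    loss = F.binary_cross_entropy_with_logits(logits, labels)  # 4) compare against known PPIs
    loss.backward()                               # 5) update encoder and predictor jointly
    opt.step()
    return loss.item()


losses = [train_step() for _ in range(50)]
```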

## BERT2Sage Results

To train BERT2Sage, we tested various configurations of the percentage of known PPIs given to the model at the start, because GraphSAGE requires at least a partial network to make predictions. All models were trained for 500 epochs. Even at 500 epochs, the model that started with 60% of known PPIs was still learning, suggesting that training for more epochs could yield further improvements.

As can be seen below, the model improved significantly when it started with more known PPIs. That said, even with just 10% of known PPIs, the model achieved strong performance, with a sensitivity of 85.9% and a specificity of 90.84%. This demonstrates that with even a little knowledge of the PPIs in the organism being studied, GNNs paired with LLMs can provide very accurate predictions of PPIs.
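Sensitivity and specificity like those quoted above can be computed from thresholded predictions as in the sketch below. The helper name and the 0.5 decision threshold in the usage example are assumptions; the threshold behind our numbers is not stated here. Because non-interacting pairs vastly outnumber interacting ones, specificity in particular depends on how many negative pairs are included in the evaluation set.

```python
def sensitivity_specificity(y_true, y_pred):
    """Return (sensitivity, specificity) for binary labels, where 1 means an interacting pair."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    sensitivity = tp / (tp + fn)  # fraction of true PPIs that the model recovers
    specificity = tn / (tn + fp)  # fraction of non-interacting pairs correctly rejected
    return sensitivity, specificity


# Usage: threshold predicted probabilities at an assumed 0.5
probs = [0.9, 0.4, 0.8, 0.2, 0.7, 0.1]
labels = [1, 1, 1, 0, 0, 0]
preds = [1 if p >= 0.5 else 0 for p in probs]
sens, spec = sensitivity_specificity(labels, preds)  # 2 of 3 positives found, 2 of 3 negatives rejected
```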

## Caveats of the BERT2Sage model

There are two potential issues with the BERT2Sage model. One is that it could be possible that ProteinBERT alone already provides robust enough embeddings that our MLP can predict PPIs accurately without the need for GraphSAGE. To address this, we trained an MLP directly on the ProteinBERT embeddings and showed that the model is unable to learn PPIs.

This demonstrates that GraphSAGE and more generally GNNs provide a valuable step in predicting PPIs by improving the embeddings provided by an LLM to create better predictions.

Another potential issue is that BERT2Sage requires some knowledge of the known PPIs since GraphSAGE needs some initial information to make predictions. This is in direct conflict with our original task, which was to predict PPIs *a priori*. While BERT2Sage does show impressive prediction capability with just 10% of known PPIs, we endeavored to design an improved model that could work completely *a priori*. This led to the development of BERT2Mult.

## BERT2Mult: Predicting PPIs by applying DistMult to ProteinBERT embeddings

DistMult [2] is a knowledge graph embedding model in which each relation, like each entity, is represented as a low-dimensional vector. As a result, DistMult does not need an initial adjacency matrix for the organism being predicted: it learns how to represent entities and relations in a continuous vector space from their occurrences in the training graphs. DistMult takes as input a triple *(h, r, t)*, where *h* is the head entity, *t* is the tail entity, and *r* is the relation, and it learns to predict missing triples by modeling the likelihood of observed triples. We can therefore train DistMult on our ProteinBERT protein embeddings using the training data from the 65 STRING organisms, without requiring any known PPIs for the 4 test organisms, since DistMult should ideally have learned the proper PPI relation *r*. This framework accomplishes something essential to our task that GraphSAGE could not: predicting PPIs *a priori*. We can implement DistMult in plain PyTorch with the following code snippet:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NSDistMultCE(nn.Module):
    def __init__(self, in_dim=1024, embedding_dim=200, c_neg=None, regularization=0.002,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        super().__init__()
        self.in_dim = in_dim                  # Dimension of the ProteinBERT embedding
        self.embedding_dim = embedding_dim    # Lower dimension used by DistMult
        self.c_neg = c_neg                    # Weight on negative pairs (ratio of positive to negative pairs); if None it is calculated for each organism
        self.regularization = regularization  # Regularization weight that keeps the parameter norms from exploding
        self.reduction = torch.sum            # Reduction used to calculate the total loss
        self.device = device                  # Device used to hold data and perform calculations
        self.linear_emb = nn.Linear(in_dim, embedding_dim, bias=False, device=self.device)  # Projects the original embedding to the lower dimension
        self.tanh = nn.Tanh()                 # Non-linearity applied to the projected embeddings
        self.linear_score = nn.Linear(embedding_dim, 1, bias=False, device=self.device)  # Its weights act as the diagonal PPI relation vector
        self.sigmoid = nn.Sigmoid()           # Maps raw scores to probabilities
        self.criterion = nn.ReLU()            # Applied to the total loss (a no-op, since the loss is non-negative)

    def reset_parameters(self):
        self.linear_emb.reset_parameters()
        self.linear_score.reset_parameters()

    def forward(self, embeddings, positive_pairs):
        # Transform embeddings to the low dimension and determine c_neg
        embeddings = embeddings.squeeze().to(self.device)
        low_dim_embeddings = self.embed(embeddings)
        c_neg = self.c_neg
        if c_neg is None:
            c_neg = len(positive_pairs[0]) / (embeddings.shape[0]**2 - len(positive_pairs[0]))

        # Calculate L_p: -log(s) for positives, plus a term that cancels their
        # contribution to the all-pairs negative term L_a below
        positive_scores = self.get_score(low_dim_embeddings, positive_pairs)
        L_p = self.reduction(-torch.log(positive_scores) + c_neg * torch.log(1 - positive_scores), dtype=torch.float64)
        del positive_scores

        # Calculate L_a: weighted -log(1 - s) over all N x N pairs, with scores Z diag(r) Z^T
        low_dim_embeddings_r = torch.einsum('i,ij->ij', self.linear_score.weight.squeeze(), low_dim_embeddings.T)
        all_scores = self.sigmoid(torch.mm(low_dim_embeddings, low_dim_embeddings_r))
        L_a = -c_neg * self.reduction(torch.log(1 - all_scores), dtype=torch.float64)
        del low_dim_embeddings_r
        del all_scores

        return self.loss(L_p, L_a)

    def embed(self, embeddings):
        # Transform the original embedding into the lower dimension
        return self.tanh(self.linear_emb(embeddings))

    def predict(self, embeddings, pairs):
        # Get scores for the given pairs from high-dimensional (ProteinBERT) embeddings
        low_dim_embeddings = self.embed(embeddings.to(self.device))
        scores = self.get_score(low_dim_embeddings, pairs)
        return scores

    def get_score(self, embeddings, pairs):
        # Get scores for the given pairs from low-dimensional embeddings:
        # sigmoid(sum_k head_k * r_k * tail_k)
        heads = embeddings[pairs[0]]
        tails = embeddings[pairs[1]]
        raw_scores = self.linear_score(heads * tails)
        return self.sigmoid(raw_scores)

    def loss(self, L_p, L_a):
        # Weighted binary cross-entropy plus a norm penalty on both weight matrices
        return self.criterion((L_p + L_a) + self.regularization * (torch.norm(self.linear_score.weight) + torch.norm(self.linear_emb.weight)))
```
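Written out, the score and training objective implemented by the `NSDistMultCE` class above are as follows. Here x_i is the ProteinBERT embedding of protein i, W is the weight of `linear_emb`, w_r is the weight of `linear_score` (the learned PPI relation), d is `embedding_dim`, E^+ is the set of known PPIs of an organism with N proteins, λ is `regularization`, and the expression for c_neg applies when `c_neg` is left as `None`. The second sum runs over all ordered pairs (i, j), including i = j, that are not known PPIs; the code obtains it by summing over all N² pairs in `L_a` and cancelling the positive pairs inside `L_p`. Since f_ij = f_ji, the score is symmetric, which suits undirected interactions.

```latex
\begin{aligned}
z_i &= \tanh(W x_i) \\
f_{ij} &= \langle z_i, w_r, z_j \rangle = \sum_{k=1}^{d} z_{i,k}\, w_{r,k}\, z_{j,k} = z_i^{\top} \operatorname{diag}(w_r)\, z_j \\
\mathcal{L} &= -\sum_{(i,j) \in E^{+}} \log \sigma(f_{ij}) \;-\; c_{\mathrm{neg}} \sum_{(i,j) \notin E^{+}} \log\!\left(1 - \sigma(f_{ij})\right) \;+\; \lambda \left( \lVert w_r \rVert_2 + \lVert W \rVert_F \right),
\qquad c_{\mathrm{neg}} = \frac{|E^{+}|}{N^2 - |E^{+}|}
\end{aligned}
```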

In this case, we do not need a separate MLP: DistMult scores each protein pair directly through the learned PPI relation, so the same model that projects the protein embeddings also predicts the interactions between proteins. The training loop is therefore relatively straightforward.
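Once trained, predicting PPIs *a priori* for a new organism only requires its ProteinBERT embeddings. A minimal sketch of ranking every candidate pair is shown below; the helper `rank_candidate_ppis` is hypothetical. With the `NSDistMultCE` class, its inputs would come from `model.embed(...)` and `model.linear_score.weight.squeeze()`, and the call would be wrapped in `torch.no_grad()` at inference time. The full N x N score matrix is affordable at this scale: for a 3,440-protein organism it holds about 11.8 million scores, roughly 47 MB in float32.

```python
import torch


def rank_candidate_ppis(low_dim: torch.Tensor, relation: torch.Tensor, k: int):
    """Rank all unordered protein pairs of one organism by DistMult score (sketch).

    low_dim:  (N, d) projected protein embeddings, e.g. tanh(W x) for ProteinBERT features x
    relation: (d,) learned PPI relation vector w_r
    Returns the top-k pairs as a (2, k) index tensor and their probabilities.
    """
    # probs[i, j] = sigmoid(sum_d z_i[d] * w_r[d] * z_j[d]); symmetric in i and j
    probs = torch.sigmoid((low_dim * relation) @ low_dim.T)
    # Keep each unordered pair once (i < j), which also drops self-pairs
    n = low_dim.size(0)
    i, j = torch.triu_indices(n, n, offset=1)
    pair_probs = probs[i, j]
    top = torch.topk(pair_probs, k=min(k, pair_probs.numel()))  # sorted, highest first
    return torch.stack([i[top.indices], j[top.indices]]), top.values


# Usage with random stand-ins for a trained model's outputs
torch.manual_seed(0)
z = torch.tanh(torch.randn(6, 4))  # stand-in for model.embed(bert_embeddings)
w_r = torch.randn(4)               # stand-in for model.linear_score.weight.squeeze()
pairs, scores = rank_candidate_ppis(z, w_r, k=3)
```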

## BERT2Mult Results

After training for 1200 epochs, BERT2Mult achieved an accuracy of 94.5% and a specificity of 89.4%.

This demonstrates a small improvement over BERT2Sage and has the benefit of being completely *a priori*. Thus, BERT2Mult demonstrates the power of combining LLMs and GNNs together in order to predict PPIs in previously uncharacterized organisms.

## Conclusions and Future Work

We hope that the work demonstrated here inspires others to combine LLMs with GNNs to make predictions about interactions between complex objects. We found that BERT2Mult was able to accurately predict PPIs completely *a priori*. In the future, we hope to train BERT2Mult on a larger dataset so that researchers can use it as a tool to search for PPIs in newly sequenced organisms without needing any known information from the organism. Future work should also benchmark BERT2Mult against other state-of-the-art approaches, such as *AlphaFold2* [5], to better assess the advance made by BERT2Mult.

## References

[1] Hamilton, W., Ying, R. & Leskovec, J. Inductive Representation Learning on Large Graphs. *arXiv*, (2017). https://doi.org/10.48550/arXiv.1706.02216

[2] Yang, B., Yih, W., He, X., Gao, J. & Deng, L. Embedding Entities and Relations for Learning and Inference in Knowledge Bases. *arXiv*, (2014). https://doi.org/10.48550/arXiv.1412.6575

[3] Szklarczyk, D., Gable, A. L., Lyon, D., Junge, A., Wyder, S., Huerta-Cepas, J., Simonovic, M., Doncheva, N. T., Morris, J. H., Bork, P., Jensen, L. J. & von Mering, C. STRING v11: Protein–protein association networks with increased coverage, supporting functional discovery in genome-wide experimental datasets. *Nucleic Acids Research*, 47(D1), (2018). https://doi.org/10.1093/nar/gky1131

[4] Brandes, N., Ofer, D., Peleg, Y., Rappoport, N. & Linial, M. ProteinBERT: A universal deep-learning model of protein sequence and function. *Bioinformatics*, 38(8), 2102–2110, (2022). https://doi.org/10.1093/bioinformatics/btac020

[5] Bryant, P., Pozzati, G. & Elofsson, A. Improved prediction of protein-protein interactions using AlphaFold2. *Nature Communications*, 13, 1265, (2022). https://doi.org/10.1038/s41467-022-28865-w