Protein Function Prediction Using Gene Ontology Embeddings and Deep Learning

Embedding ontology axioms to tackle the CAFA4 Challenge with DeepGOZero

Adrian Yijie Xu, PhD
GradientCrescent
22 min read · Jun 7, 2024


Introduction

In recent years, there has been a surge of interest in applying deep learning techniques to bioinformatics and proteomics, particularly for predicting protein structure, function, and interactions. An example of this is DeepMind's AlphaFold family of models, which has achieved unprecedented accuracy in predicting protein structures and protein-ligand interactions. These solutions enhance our understanding of protein interactions, folding patterns, and functional mechanisms, driving advancements in drug discovery, disease research, and synthetic biology.

To evaluate the performance of computational approaches to protein function prediction, the Critical Assessment of Functional Annotation (CAFA) competition has been held biennially as a benchmark. These challenges task participating groups with predicting the functions of a large set of proteins, and their predictions are later compared against experimental data. CAFA provides an objective benchmark for assessing the performance of various algorithms, fostering innovation and improvement in protein function prediction techniques. Our team from Silo AI recently participated in the CAFA5 challenge held on Kaggle, achieving a ranking of 41/1625. Amongst the various approaches, DeepGOZero stood out to us as an innovative and effective one, as it incorporates the axiomatic knowledge of the Gene Ontology into the overall prediction process.

Figure 1. An example of using EL Embeddings, a model-theoretic approach, to predict the function of a protein P via a geometric representation and axiomatic relationships, as is done in DeepGOZero.[4]

The objective of this tutorial is to introduce the concept of Gene Ontologies, relevant background context, and how they are used in the DeepGOZero library to enhance the protein function prediction process.

Background

What are Gene Ontologies?

An ontology is a formal representation of a body of knowledge within a given domain. A domain can be a field or area such as biology. Ontologies usually consist of a set of classes, or terms, with relations that operate between them.[1,2]

For protein function prediction and DeepGOZero, we specifically refer to gene ontologies (GO), which are concerned with knowledge related to the function and organization of genes.

Briefly, GOs attempt to describe gene function with respect to three aspects[1,2]:

  • Molecular Functions (MF): the molecular-level activities performed by gene products.
  • Cellular Components (CC): the locations relative to cellular structures in which a gene product performs a function.
  • Biological Processes (BP): the larger processes, or ‘biological programs’ accomplished by multiple molecular activities.

Each of these three aspects is an independent domain represented by a root ontology term, with no root term being linked to another. Individual classes, or terms, can trace their parentage to one of these three root terms, although possibly through numerous different paths.[1,2] Terms are defined using relationships to other terms, called axioms, such as a term being a subclass of another (is_a) or a part of another (part_of). Moreover, the three root terms do not share a parent term, although certain relationships may exist between the ontologies (a term from one ontology may regulate a term from another).[2]

Functions are described using the Gene Ontology (GO), which is a large ontology with over 50 000 classes.[1] It contains three sub-ontologies: the Molecular Functions Ontology (MFO), the Biological Processes Ontology (BPO) and the Cellular Components Ontology (CCO). The classes in these three ontologies are not independent and stand in formally defined relations that need to be considered while making predictions.[4]
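For intuition, here is an abbreviated (and simplified) example of how a term and its subclass axiom appear in the OBO format in which GO is distributed; the Ontology class we cover later parses exactly this structure:

[Term]
id: GO:0043167
name: ion binding
namespace: molecular_function
is_a: GO:0005488 ! binding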

How are GOs used for protein function prediction?

Figure 2. An example ontology, showing the hierarchical relationships of different GO terms.[1]

Experimental identification of protein functions is time-consuming and resource-intensive. Next-generation sequencing technologies have led to a significant increase in the number of available DNA and protein sequences, thereby amplifying the necessity and challenge of identifying protein functions.[4]

Building on the results of targeted experimental studies, computational researchers have developed methods that analyze and predict protein structures [6,7], determine the interactions between proteins [8,9] and predict protein functions [10] based on protein amino acid sequences.[4] Several deep learning approaches have been developed and applied to protein function prediction [4,11,12]; however, several challenges still prevent widespread adoption, primarily[4]:

  • It remains challenging to determine how a computational model can learn efficiently from protein sequences alone or together with other sources of information such as the protein structure, interactions or literature.
  • It remains challenging to predict the correct set of functions in the large, complex, unbalanced and hierarchical space of biological functions.
  • The large GO term space combined with poor or sparse annotation availability inherently hinders the learning of computational models.

In particular, the low number or absence of annotations for numerous examples presents a challenge for training ML models to predict protein functions or to make similarity-based predictions for GO terms. This presents a significant class imbalance challenge: more than 20 000 terms in the GO have fewer than 100 proteins annotated (based on experimental evidence), and these terms are often specific and therefore highly informative. Many GO terms have never been used to characterize the functions of a protein, or only a few proteins have been characterized with a particular function.[4]

Interestingly, many of those terms have been formally defined by reusing other terms and relations using hierarchical Description Logic axioms [5]. Model-theoretic languages such as Description Logics can also be used to express relational knowledge while adding operators that cannot easily be expressed in graph-based form (i.e. quantifiers, negation, conjunction, disjunction).[4] The simplest form of axiom is the subclass axiom, such as Ion binding (GO:0043167) being a SubClassOf Binding (GO:0005488), which indicates that every instance of Ion binding must also be an instance of Binding.[4,13] These axioms also constrain annotations of proteins: if a protein P has the function Ion binding, it will also have the function Binding. Such axioms are used in several protein function prediction methods, and exploiting them has been shown to improve performance [10,11]. As GO has around 12 000 definition axioms and more than 500 000 other axioms, these can be utilized to make predictions for proteins with few (<100) or no experimental annotations.[4]
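As a minimal sketch of how a subclass axiom can be exploited in practice (our own toy illustration, not DeepGOZero code), annotations can be propagated upward through a parent map:

# Toy sketch of annotation propagation via subclass (is_a) axioms.
# The parents mapping is hypothetical and would normally be parsed from go.obo.
parents = {
    'GO:0043167': {'GO:0005488'},  # ion binding is_a binding
    'GO:0005488': set(),           # binding (no parents shown here)
}

def propagate(annotations):
    # Add every ancestor of each annotated term
    result = set(annotations)
    queue = list(annotations)
    while queue:
        term = queue.pop()
        for parent in parents.get(term, ()):
            if parent not in result:
                result.add(parent)
                queue.append(parent)
    return result

print(propagate({'GO:0043167'}))  # {'GO:0043167', 'GO:0005488'}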

However, there are axioms beyond subclass axioms in the GO (The Gene Ontology Consortium, 2018). For example, the term serine-type endopeptidase inhibitor activity (GO:0004867) is defined as being equivalent to a molecular function regulator (GO:0098772) that negatively regulates (RO:0002212) some serine-type endopeptidase activity (GO:0004252).[4]

Based on this axiom, if we know that a protein P is a molecular function regulator and can negatively regulate serine-type endopeptidase activity, we can predict that the protein exhibits the function serine-type endopeptidase inhibitor activity. This axiom can therefore be used in two ways by a function prediction model[4]:

  • First, it imposes an additional constraint on functions predicted for P, and this constraint may both reduce the search space during optimization and improve prediction accuracy.
  • Second, the axiom can allow us to predict the function serine-type endopeptidase inhibitor activity for a protein P even if not a single protein seen during training has this function.

Knowledge Graph Embeddings & Protein Function Prediction

Knowledge graphs are semantic, mutable graph-based network structures, characterized by nodes and edges that represent entities and relations; ontology structures are one example of a type of knowledge graph. Knowledge graph embeddings are commonly used to map the entities and relations expressed in a knowledge graph into a vector space, aiming to preserve relational and other semantic information of the underlying graph.[4] These mapped entities can then be used as features for machine learning tasks, or for determining semantic similarity by measuring their relative positions within the vector space.[3,4]
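To make this concrete, a classic knowledge graph embedding model such as TransE scores a (head, relation, tail) triple by how closely head + relation lands on tail in vector space. The toy sketch below is our own illustration of that general idea, not part of DeepGOZero:

import numpy as np

# Toy TransE-style scoring: a triple (h, r, t) is plausible when h + r is close to t.
rng = np.random.default_rng(0)
dim = 8
h, r, t = rng.normal(size=(3, dim))  # randomly initialized embeddings

def transe_score(h, r, t):
    # Lower distance (higher score) means a more plausible triple
    return -np.linalg.norm(h + r - t)

print(transe_score(h, r, t))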

Knowledge graph embedding functions are typically trained through optimization with respect to an objective function and, optionally, a set of constraints. However, inference using such approaches is typically limited to using existing relationships between entities to infer new relationships, commonly termed the composition of relations.[3,4] Similarly, attempts to utilize “ontology embeddings” that aim to preserve the syntactic or axiomatic logic of the underlying ontologies in vector space for inference purposes have been frustrated by several factors, including[3,5]:

  • The inability to dynamically infer beyond the defined precomputed inferences of the underlying ontology input.
  • The poor expressiveness of the underlying ontology representation language, which is inherited by the resulting embedding.
  • The poor utilization of prior knowledge about the semantics of logical operators such as quantifiers (for all, there exists), negation (not), conjunction (and), and disjunction (or).

To address these limitations, Kulmanov et al. proposed a geometric ontology embedding method, EL Embeddings, which is based on the EL++ Description Logic and aims to extend the capabilities of a knowledge graph by incorporating the EL++ operators of conjunction, existential quantification, and the bottom concept (an always-false logical statement).[3,4,5] While a detailed treatment of EL Embeddings is beyond the scope of this work, the method can be summarized as utilizing a relational embedding model that represents terms as n-balls and relations as vectors in order to embed ontology semantics into a geometric model. As a reference, a visualization of EL Embeddings for an ontology of family members is shown below in Figure 3, illustrating how the relationships of the ontology entities defined through axioms (in this case, mostly limited to indicating that a class is a subclass of another) can be effectively visualized in 2D space.[3]

Figure 3. Visualization of EL Embeddings used for a family test case, showing the positions and relationships of various entities in 2D space. Note how for more specific classes the radius is small, while for broader classes the radius is larger.
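To build intuition for this geometry, here is a toy sketch under our own simplifying assumptions (not library code): a subclass axiom C SubClassOf D is satisfied when the n-ball of C lies entirely inside the n-ball of D, which can be checked directly from centers and radii:

import numpy as np

# C SubClassOf D holds geometrically when ball C sits inside ball D:
# distance(center_c, center_d) + radius_c <= radius_d
def is_subclass(center_c, radius_c, center_d, radius_d):
    return np.linalg.norm(center_c - center_d) + radius_c <= radius_d

binding = (np.array([0.0, 0.0]), 2.0)      # broad class: large radius
ion_binding = (np.array([0.5, 0.5]), 0.5)  # specific class: small radius

print(is_subclass(*ion_binding, *binding))  # True: ion binding lies inside binding

This distance-plus-radii quantity is exactly what the nf1_loss method penalizes in the DeepGOZero code later in this post.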

DeepGOZero

DeepGOZero (DGZ) is an architecture proposed by Kulmanov and Hoehndorf, and is an evolution of the earlier DeepGO and DeepGOPlus approaches evaluated in CAFA4.[4]

The aim of DeepGOZero is to use the background knowledge contained in the Description Logic axioms of GO to improve protein function prediction. In particular, DeepGOZero aims to predict functional annotations for ontology terms without training samples using only the GO axioms (zero-shot).[4]

At the core of DeepGOZero are the MLP and the ELEmbeddings models. EL Embeddings are applied to the input GO terms to generate an n-dimensional space in which GO terms are n-balls whose locations and sizes are constrained by the GO axioms. The ontology axioms are only used during the training phase of DeepGOZero, as losses that ensure the space respects them.[4]

We then use a neural network to project proteins into the same space in which we embedded the GO classes, and predict functions for proteins by their proximity and relation to GO classes. For a given protein, relevant InterPro domain annotations (linked to protein function), obtained through services such as InterProScan, are fed as a binary vector into a 2-layer MLP to generate a 1024-dimensional embedding.[4]

During the forward pass, the prediction for an output class c is then obtained as the matrix product of the MLP-derived protein embedding with the class embedding translated by the hasFunction relation embedding (both ELEmbeddings components), with the class radius added.

Formulaically, this can be expressed as[4]:

yc = σ( fn(p) · (fn(hF) + fn(c)) + rn(c) )

where:

fn(p) is the embedding of a protein, which is generated by the MLP.

fn(hF) is the embedding of the hasFunction (hF) relation.

fn(c) is the embedding of class c.

rn(c) is the radius of class c.
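As a minimal PyTorch sketch of this scoring step, with made-up dimensions (the actual forward pass appears in DGELModel later in this post):

import torch as th

batch, embed_dim, n_classes = 4, 1024, 3
p = th.randn(batch, embed_dim)          # fn(p): protein embeddings from the MLP
c = th.randn(n_classes, embed_dim)      # fn(c): class embeddings
hf = th.randn(embed_dim)                # fn(hF): hasFunction relation embedding
r = th.abs(th.randn(1, n_classes))      # rn(c): class radii, kept non-negative

# yc = sigmoid(fn(p) · (fn(hF) + fn(c)) + rn(c))
scores = th.sigmoid(p @ (c + hf).T + r)
print(scores.shape)  # torch.Size([4, 3])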

To train the model, DeepGOZero jointly minimizes the prediction loss for protein functions and the ELEmbeddings losses that impose constraints on the GO classes.[4] This is done by computing the binary cross-entropy loss between our predictions and the labels, and optimizing it together with four normal form losses for the ontology axioms from ELEmbeddings to ensure that the Description Logic is respected. Normal forms are alternative representations of ontology relationships used in reasoning and inference tasks.[4,5]

Formulaically, this can be described as[4]:

L = BCELoss + LNF1 + LNF2 + LNF3 + LNF4

where BCELoss refers to the binary cross-entropy loss between predictions and labels, and LNF1, LNF2, LNF3 and LNF4 are the respective normal form losses. The exact definitions and conversion processes of these normal form losses are beyond the scope of this post, but all four losses work to ensure that the distances between the n-balls of different classes in the n-dimensional space respect the axioms in GO.[4,5]
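For reference, the four normal forms, exactly as parsed by load_normal_forms later in this post (C, D and E are classes, R is a relation), are:

  • NF1: C SubClassOf D
  • NF2: C and D SubClassOf E
  • NF3: R some C SubClassOf D
  • NF4: C SubClassOf R some D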

Figure 4. Visualization of the DeepGOZero inference process.[4]

A visualization of the overall inference process is shown in Figure 4. On the left, a protein P is embedded in a vector space using an MLP, whereas the right side shows how GO axioms are embedded using the EL Embedding method; the MLP embeds the protein in the same space as the GO axioms. From the location of P we can annotate it with positive regulation of protein kinase B signaling (GO:0051897). This class is defined as a biological regulation (GO:0065007) that positively regulates (RO:0002213) some protein kinase B signaling (GO:0043491). This knowledge allows us to annotate subsequent proteins with GO:0051897 even if we do not have any training proteins for it (zero-shot).[4]

Implementation

Data

We downloaded the UniProt/SwissProt Knowledgebase (UniProtKB-SwissProt) (The UniProt Consortium, 2018) version 2021_04 released on September 29, 2021. We filtered all proteins with experimental functional annotations with evidence codes EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC, HTP, HDA, HMP, HGI and HEP. The dataset contains 77 647 reviewed and manually annotated proteins. For this dataset, we use GO released on November 16, 2021. We train and evaluate models for each of the sub-ontologies of GO separately.[4]

For the purpose of the notebook, this data is saved in a Kaggle dataset.

Notebook analysis: Original CAFA4 Dataset

In order to understand the workings of DGZ, let's go over how the code works in the baseline vanilla notebook. Note that for the purpose of this explanation, we've extracted code from the repository files into the notebook; we'll mention this explicitly when discussing the relevant files in question.

Let's start by understanding the structure of the data inputs. These have been prepared by Geraseva et al. of the research team [4], and contain both inputs and pre-trained models.

Figure 5. Master directory of the Kaggle notebook.

Notice the presence of the GO core ontology files, including a normal-form compatible version, obtainable from the GO consortium. Note how each of the GO categories (BP, CC, MF) has its own set of data and models. We can take a closer look at the BP category as an example.

Figure 6. Expanded BP-aspect specific directory of the Kaggle notebook.

Inspecting the contents, we observe files related to training metrics, trained models, and various post-training data files. However, a few files are of particular importance for training.

  • interpros.pkl is the InterProScan data file with sequences and the complete set of annotations.
  • terms_zero.pkl contains the aspect-specific GO terms, i.e. the target classes for prediction purposes. In terms of problem complexity, terms_zero_10.pkl contains the fewest terms, while terms.pkl contains the most terms for an ontology aspect. These terms include those evaluated in CAFA4.
  • train_data.pkl, valid_data.pkl, test_data.pkl are input datafiles, also used to construct relevant dataframes.

Let’s inspect the structure and contents of the dataframe.
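Here, train_df is assumed to have been loaded from the pickled training split, along the lines of:

import pandas as pd

# Assumed load of the pickled BP training split (paths as configured below)
data_root = '/kaggle/input/deepgozero-data/data'
ont = 'bp'
train_df = pd.read_pickle(f'{data_root}/{ont}/train_data.pkl')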

train_df.columns
Index(['index', 'proteins', 'accessions', 'genes', 'sequences', 'annotations',
'string_ids', 'orgs', 'interpros', 'exp_annotations',
'prop_annotations', 'cafa_target', 'zero_annotations'])
Figure 7. Snapshot of the input dataframe.

From the above, we observe columns related to the protein identity (proteins, accessions), characteristics (sequences, interpros), and annotations.

Defining data directories & parameters:

# Data directories and parameters
data_root = '/kaggle/input/deepgozero-data/data'
ont = 'bp'  # GO ontology aspect ['bp', 'mf', 'cc']
device = 'cuda:0'
batch_size = 37
epochs = 256
load = True  # Set to False for retraining
go_file = f'{data_root}/go.norm'
model_file = f'{data_root}/{ont}/deepgozero_zero_10.th'
terms_file = f'{data_root}/{ont}/terms_zero_10.pkl'
out_file = f'{data_root}/{ont}/predictions_deepgozero_zero_10.pkl'
  • data_root: directory of input data extracted from data.tar.gz
  • ont: GO ontology aspect of interest, as DGZ model training & inference is aspect-specific.
  • go_file: GO core ontology file, normal-form compatible.
  • model_file: ontology-specific trained DGZ file. Note that in the vanilla dataset, three model files are provided per ontology, covering differing amounts of GO annotations (deepgozero.th, deepgozero_zero.th, deepgozero_zero_10.th).
  • terms_file: ontology aspect-specific GO terms-file.
  • out_file: ontology aspect-specific output file directory.

To replicate the DeepGOZero component of the original study, we will extract code from the following files in the repository:

  • utils.py: utility methods & ontology class definition.
  • deepgozero.py: main training & inference script, data loading & structuring, and model definition.
  • torch_utils.py: custom tensor dataloader.

Let’s start with utils.py, which contains auxiliary supporting methods, focusing on the Ontology class in particular. The Ontology class is designed to load, manage, and analyze a Gene Ontology (GO) from a file in OBO format. The ontology is structured as a directed acyclic graph where terms have parent-child relationships. This class is utilized to identify related annotations.

# from https://github.com/bio-ontology-research-group/deepgozero/blob/main/utils.py
# Ontology hierarchy handling methods
import math
from collections import Counter, deque


class Ontology(object):
    # Loads data, optionally with relationships
    def __init__(self, filename='data/go.obo', with_rels=False):
        self.ont = self.load(filename, with_rels)
        self.ic = None
        self.ic_norm = 0.0

    # Check that term exists in ontology
    def has_term(self, term_id):
        return term_id in self.ont

    # Check and retrieve term
    def get_term(self, term_id):
        if self.has_term(term_id):
            return self.ont[term_id]
        return None

    # Calculates the information content (IC)
    # of each term based on the annotations provided.
    # IC measures the specificity of a term within an ontology.
    def calculate_ic(self, annots):
        cnt = Counter()
        for x in annots:
            cnt.update(x)
        self.ic = {}
        for go_id, n in cnt.items():
            parents = self.get_parents(go_id)
            if len(parents) == 0:
                min_n = n
            else:
                min_n = min([cnt[x] for x in parents])

            self.ic[go_id] = math.log(min_n / n, 2)
            self.ic_norm = max(self.ic_norm, self.ic[go_id])

    # Retrieves the IC value for a given GO term ID.
    def get_ic(self, go_id):
        if self.ic is None:
            raise Exception('Not yet calculated')
        if go_id not in self.ic:
            return 0.0
        return self.ic[go_id]

    # Retrieves the normalized IC value for a given GO term ID
    def get_norm_ic(self, go_id):
        return self.get_ic(go_id) / self.ic_norm

    # Loads ontology data from a file into memory.
    # It parses the data and organizes it into a dictionary structure
    # representing the ontology.
    def load(self, filename, with_rels):
        ont = dict()
        obj = None
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line == '[Term]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = dict()
                    obj['is_a'] = list()
                    obj['part_of'] = list()
                    obj['regulates'] = list()
                    obj['alt_ids'] = list()
                    obj['is_obsolete'] = False
                    continue
                elif line == '[Typedef]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = None
                else:
                    if obj is None:
                        continue
                    l = line.split(": ")
                    if l[0] == 'id':
                        obj['id'] = l[1]
                    elif l[0] == 'alt_id':
                        obj['alt_ids'].append(l[1])
                    elif l[0] == 'namespace':
                        obj['namespace'] = l[1]
                    elif l[0] == 'is_a':
                        obj['is_a'].append(l[1].split(' ! ')[0])
                    elif with_rels and l[0] == 'relationship':
                        it = l[1].split()
                        # add all types of relationships
                        obj['is_a'].append(it[1])
                    elif l[0] == 'name':
                        obj['name'] = l[1]
                    elif l[0] == 'is_obsolete' and l[1] == 'true':
                        obj['is_obsolete'] = True
        if obj is not None:
            ont[obj['id']] = obj
        for term_id in list(ont.keys()):
            for t_id in ont[term_id]['alt_ids']:
                ont[t_id] = ont[term_id]
            if ont[term_id]['is_obsolete']:
                del ont[term_id]
        for term_id, val in ont.items():
            if 'children' not in val:
                val['children'] = set()
            for p_id in val['is_a']:
                if p_id in ont:
                    if 'children' not in ont[p_id]:
                        ont[p_id]['children'] = set()
                    ont[p_id]['children'].add(term_id)

        return ont

    # Retrieves all ancestors (immediate and distant) of a given term.
    def get_anchestors(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while len(q) > 0:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for parent_id in self.ont[t_id]['is_a']:
                    if parent_id in self.ont:
                        q.append(parent_id)
        return term_set

    # Retrieves all terms that are propagated up the ontology hierarchy
    # from the given set of terms.
    def get_prop_terms(self, terms):
        prop_terms = set()
        for term_id in terms:
            prop_terms |= self.get_anchestors(term_id)
        return prop_terms

    # Retrieves immediate parent terms of a given term.
    def get_parents(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        for parent_id in self.ont[term_id]['is_a']:
            if parent_id in self.ont:
                term_set.add(parent_id)
        return term_set

    # Retrieves all terms belonging to a specified namespace.
    def get_namespace_terms(self, namespace):
        terms = set()
        for go_id, obj in self.ont.items():
            if obj['namespace'] == namespace:
                terms.add(go_id)
        return terms

    # Retrieves the namespace of a given term.
    def get_namespace(self, term_id):
        return self.ont[term_id]['namespace']

    # Retrieves all terms in the subgraph rooted at the given term.
    def get_term_set(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while len(q) > 0:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for ch_id in self.ont[t_id]['children']:
                    q.append(ch_id)
        return term_set
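As a quick usage sketch of this class (our own example; the term ID is the Ion binding class from earlier):

# Hypothetical usage of the Ontology class defined above
go = Ontology(f'{data_root}/go.obo', with_rels=True)

term = 'GO:0043167'  # ion binding
if go.has_term(term):
    print(go.get_term(term)['name'])  # name parsed from the OBO stanza
    print(go.get_anchestors(term))    # the term plus all of its is_a ancestors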

Next, we move on to inspect code from deepgozero.py. This includes the ROC-AUC calculation, the normal form parsing, as well as the key data loading methods.

# from https://github.com/bio-ontology-research-group/deepgozero/blob/main/deepgozero.py
import pandas as pd
import torch as th
from sklearn.metrics import roc_curve, auc


def compute_roc(labels, preds):
    # Compute the ROC curve and ROC area over all flattened class scores
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc


# Loads normal forms (NF) from a Gene Ontology (GO) file.
def load_normal_forms(go_file, terms_dict):
    nf1 = []
    nf2 = []
    nf3 = []
    nf4 = []
    relations = {}
    zclasses = {}

    def get_index(go_id):
        if go_id in terms_dict:
            index = terms_dict[go_id]
        elif go_id in zclasses:
            index = zclasses[go_id]
        else:
            zclasses[go_id] = len(terms_dict) + len(zclasses)
            index = zclasses[go_id]
        return index

    def get_rel_index(rel_id):
        if rel_id not in relations:
            relations[rel_id] = len(relations)
        return relations[rel_id]

    with open(go_file) as f:
        for line in f:
            line = line.strip().replace('_', ':')
            if line.find('SubClassOf') == -1:
                continue
            left, right = line.split(' SubClassOf ')
            # C SubClassOf D
            if len(left) == 10 and len(right) == 10:
                go1, go2 = left, right
                nf1.append((get_index(go1), get_index(go2)))
            elif left.find('and') != -1:  # C and D SubClassOf E
                go1, go2 = left.split(' and ')
                go3 = right
                nf2.append((get_index(go1), get_index(go2), get_index(go3)))
            elif left.find('some') != -1:  # R some C SubClassOf D
                rel, go1 = left.split(' some ')
                go2 = right
                nf3.append((get_rel_index(rel), get_index(go1), get_index(go2)))
            elif right.find('some') != -1:  # C SubClassOf R some D
                go1 = left
                rel, go2 = right.split(' some ')
                nf4.append((get_index(go1), get_rel_index(rel), get_index(go2)))
    return nf1, nf2, nf3, nf4, relations, zclasses


# Create root ontology-specific dictionaries of GO terms and InterPro annotations
def load_data(data_root, ont, terms_file):
    terms_df = pd.read_pickle(terms_file)
    terms = terms_df['gos'].values.flatten()
    terms_dict = {v: i for i, v in enumerate(terms)}
    print('Terms', len(terms))

    ipr_df = pd.read_pickle(f'{data_root}/{ont}/interpros.pkl')
    iprs = ipr_df['interpros'].values
    iprs_dict = {v: k for k, v in enumerate(iprs)}
    return iprs_dict, terms_dict


# Reference dictionaries against input dataframe to build input tensor pairs
def get_data(df, iprs_dict, terms_dict):
    data = th.zeros((len(df), len(iprs_dict)), dtype=th.float32)
    labels = th.zeros((len(df), len(terms_dict)), dtype=th.float32)
    for i, row in enumerate(df.itertuples()):
        for ipr in row.interpros:
            if ipr in iprs_dict:
                data[i, iprs_dict[ipr]] = 1
        for go_id in row.prop_annotations:  # prop_annotations for full model
            if go_id in terms_dict:
                g_id = terms_dict[go_id]
                labels[i, g_id] = 1
    return data, labels

The two key methods here are load_data and get_data.

  • load_data takes in the ontology aspect-specific terms and InterPro input files to construct their respective dictionaries, which are then fed into the latter.
  • get_data cross-references these dictionaries against the input dataframe in order to create dataframe-specific input and label tensors, which are then fed into the dataloader.

Next, we move on to inspect the model architecture, also defined in deepgozero.py.

# from https://github.com/bio-ontology-research-group/deepgozero/blob/main/deepgozero.py
# Model structure
import math

import torch as th
import torch.nn as nn


class Residual(nn.Module):

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return x + self.fn(x)


class MLPBlock(nn.Module):

    def __init__(self, in_features, out_features, bias=True, layer_norm=True,
                 dropout=0.1, activation=nn.ReLU):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.activation = activation()
        self.layer_norm = nn.BatchNorm1d(out_features, track_running_stats=False) if layer_norm else None
        self.dropout = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        x = self.activation(self.linear(x))
        if self.layer_norm:
            x = self.layer_norm(x)
        if self.dropout:
            x = self.dropout(x)
        return x


class DGELModel(nn.Module):
    # Model initialization
    def __init__(self, nb_iprs, nb_gos, nb_zero_gos, nb_rels, device,
                 hidden_dim=1024, embed_dim=1024, margin=0.1):
        super().__init__()
        self.nb_gos = nb_gos
        self.nb_zero_gos = nb_zero_gos
        input_length = nb_iprs
        # MLP embedding block
        net = []
        net.append(MLPBlock(input_length, hidden_dim))
        net.append(Residual(MLPBlock(hidden_dim, hidden_dim)))
        self.net = nn.Sequential(*net)

        # ELEmbeddings
        self.embed_dim = embed_dim
        self.hasFuncIndex = th.LongTensor([nb_rels]).to(device)
        self.go_embed = nn.Embedding(nb_gos + nb_zero_gos, embed_dim)
        self.go_norm = nn.BatchNorm1d(embed_dim)
        k = math.sqrt(1 / embed_dim)
        nn.init.uniform_(self.go_embed.weight, -k, k)
        self.go_rad = nn.Embedding(nb_gos + nb_zero_gos, 1)
        nn.init.uniform_(self.go_rad.weight, -k, k)

        self.rel_embed = nn.Embedding(nb_rels + 1, embed_dim)
        nn.init.uniform_(self.rel_embed.weight, -k, k)
        self.all_gos = th.arange(self.nb_gos).to(device)
        self.margin = margin

    # Main forward pass of DGZ
    def forward(self, features):
        x = self.net(features)
        go_embed = self.go_embed(self.all_gos)
        hasFunc = self.rel_embed(self.hasFuncIndex)
        hasFuncGO = go_embed + hasFunc
        go_rad = th.abs(self.go_rad(self.all_gos).view(1, -1))
        # Prediction score yc = sigmoid(fn(p) . (fn(hF) + fn(c)) + rn(c))
        x = th.matmul(x, hasFuncGO.T) + go_rad
        logits = th.sigmoid(x)
        return logits

    # Zero-shot prediction for classes outside the training terms
    def predict_zero(self, features, data):
        x = self.net(features)
        go_embed = self.go_embed(data)
        hasFunc = self.rel_embed(self.hasFuncIndex)
        hasFuncGO = go_embed + hasFunc
        go_rad = th.abs(self.go_rad(data).view(1, -1))
        x = th.matmul(x, hasFuncGO.T) + go_rad
        logits = th.sigmoid(x)
        return logits

    # EL Embedding model losses, consisting of normal forms 1-4
    def el_loss(self, go_normal_forms):
        nf1, nf2, nf3, nf4 = go_normal_forms
        nf1_loss = self.nf1_loss(nf1)
        nf2_loss = self.nf2_loss(nf2)
        nf3_loss = self.nf3_loss(nf3)
        nf4_loss = self.nf4_loss(nf4)
        return nf1_loss + nf2_loss + nf3_loss + nf4_loss

    # Signed distance between two class n-balls (C lies inside D when <= 0)
    def class_dist(self, data):
        c = self.go_norm(self.go_embed(data[:, 0]))
        d = self.go_norm(self.go_embed(data[:, 1]))
        rc = th.abs(self.go_rad(data[:, 0]))
        rd = th.abs(self.go_rad(data[:, 1]))
        dist = th.linalg.norm(c - d, dim=1, keepdim=True) + rc - rd
        return dist

    # NF1: C SubClassOf D
    def nf1_loss(self, data):
        pos_dist = self.class_dist(data)
        loss = th.mean(th.relu(pos_dist - self.margin))
        return loss

    # NF2: C and D SubClassOf E
    def nf2_loss(self, data):
        c = self.go_norm(self.go_embed(data[:, 0]))
        d = self.go_norm(self.go_embed(data[:, 1]))
        e = self.go_norm(self.go_embed(data[:, 2]))
        rc = th.abs(self.go_rad(data[:, 0]))
        rd = th.abs(self.go_rad(data[:, 1]))
        re = th.abs(self.go_rad(data[:, 2]))

        sr = rc + rd
        dst = th.linalg.norm(c - d, dim=1, keepdim=True)
        dst2 = th.linalg.norm(e - c, dim=1, keepdim=True)
        dst3 = th.linalg.norm(e - d, dim=1, keepdim=True)
        loss = th.mean(th.relu(dst - sr - self.margin)
                       + th.relu(dst2 - rc - self.margin)
                       + th.relu(dst3 - rd - self.margin))
        return loss

    # NF3: R some C SubClassOf D
    def nf3_loss(self, data):
        rE = self.rel_embed(data[:, 0])
        c = self.go_norm(self.go_embed(data[:, 1]))
        d = self.go_norm(self.go_embed(data[:, 2]))
        rc = th.abs(self.go_rad(data[:, 1]))
        rd = th.abs(self.go_rad(data[:, 2]))

        rSomeC = c + rE
        euc = th.linalg.norm(rSomeC - d, dim=1, keepdim=True)
        loss = th.mean(th.relu(euc + rc - rd - self.margin))
        return loss

    # NF4: C SubClassOf R some D
    def nf4_loss(self, data):
        c = self.go_norm(self.go_embed(data[:, 0]))
        rE = self.rel_embed(data[:, 1])
        d = self.go_norm(self.go_embed(data[:, 2]))
        rc = th.abs(self.go_rad(data[:, 1]))
        rd = th.abs(self.go_rad(data[:, 2]))
        sr = rc + rd
        # c should intersect with d + r
        rSomeD = d + rE
        dst = th.linalg.norm(c - rSomeD, dim=1, keepdim=True)
        loss = th.mean(th.relu(dst - sr - self.margin))
        return loss

Let’s inspect the DGELModel class in detail. Initialization of the model requires the following parameters:

  • nb_iprs: length of the IPRS dictionary, specific to ontology aspect and input data.
  • nb_gos: length of the GO terms dictionary, specific to ontology aspect and input data.
  • nb_zero_gos: length of the zero-shot GO terms dictionary, i.e. the classes encountered while parsing the normal forms in go.norm that are absent from the input terms dictionary.
  • nb_rels: length of the relations dictionary, collected while parsing the normal forms in go.norm.
  • hidden_dim: Size of the hidden dimension of the MLP blocks. Note that this is the same as embed_dim.
  • embed_dim: Size of the embedding layer output. Note that this is kept the same as hidden_dim, to allow interchangeability between protein and ELEmbeddings.
  • margin: Margin value used for loss function calculation.

So what happens during initialization?

  1. The main residual MLP network is built. This produces protein embeddings from the input vector, which encodes the InterPro annotations present (using iprs_dict) for each example in the input dataframe. The network is trained using labels obtained from the GO terms present (for the ontology aspect) in the input dataframe.
  2. EL Embedding-related components are initialized. Recall that the approach utilizes normalized GO axioms as constraints to project each GO term into an n-ball (represented as a center point + radius in n-dimensional space) and each relation as a transformation within the n-dimensional space. The components are as follows:
  • self.hasFuncIndex is a tensor holding the index of the hasFunction relation within the relation embedding table (set to nb_rels, one past the parsed relations).
  • self.go_embed is an embedding layer for GO terms. It initializes embeddings for nb_gos + nb_zero_gos terms, each with embed_dim dimensions.
  • self.go_norm is a batch normalization layer for the GO term embeddings.
  • self.go_rad is another embedding layer for GO terms, used for the radius, with each term embedded into a single dimension.
  • self.rel_embed is an embedding layer for relationships, with nb_rels + 1 entries.
  • self.all_gos is a tensor containing indices for all GO terms, used later in the forward pass.

These components are then used in the forward pass described earlier in the article.

With the main methods covered, let’s go over the primary training and inference scripts. We begin by initializing our losses, the model, and our dataloader.

loss_func = nn.BCELoss()
iprs_dict, terms_dict = load_data(data_root, ont, terms_file)
n_terms = len(terms_dict)
n_iprs = len(iprs_dict)

nf1, nf2, nf3, nf4, relations, zero_classes = load_normal_forms(go_file, terms_dict)
n_rels = len(relations)
n_zeros = len(zero_classes)

nf1 = th.LongTensor(nf1).to(device)
nf2 = th.LongTensor(nf2).to(device)
nf3 = th.LongTensor(nf3).to(device)
nf4 = th.LongTensor(nf4).to(device)
normal_forms = nf1, nf2, nf3, nf4


net = DGELModel(n_iprs, n_terms, n_zeros, n_rels, device).to(device)
print(net)

if not load:
    train_df = pd.read_pickle(f'{data_root}/{ont}/train_data.pkl')
    train_data = get_data(train_df, iprs_dict, terms_dict)
    print(train_data[0].shape)
    train_loader = FastTensorDataLoader(
        *train_data, batch_size=batch_size, shuffle=True)
    del train_df, train_data

    valid_df = pd.read_pickle(f'{data_root}/{ont}/valid_data.pkl')
    valid_data = get_data(valid_df, iprs_dict, terms_dict)
    valid_labels = valid_data[1].detach().cpu().numpy()  # kept for the AUC computation below
    print(valid_data[0].shape)
    valid_loader = FastTensorDataLoader(
        *valid_data, batch_size=batch_size, shuffle=False)

    del valid_df, valid_data

Next, we initialize our optimizer, the learning rate, and begin the training & validation processes.

if not load:
    optimizer = th.optim.Adam(net.parameters(), lr=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[5, 20], gamma=0.1)
    best_loss = 10000.0

    print('Training the model')
    for epoch in range(epochs):
        net.train()
        train_loss = 0
        train_elloss = 0
        lmbda = 0.1
        train_steps = len(train_loader)
        for batch_features, batch_labels in train_loader:
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)
            logits = net(batch_features)
            loss = F.binary_cross_entropy(logits, batch_labels)
            el_loss = net.el_loss(normal_forms)
            total_loss = loss + el_loss
            train_loss += loss.detach().item()
            train_elloss = el_loss.detach().item()
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        train_loss /= train_steps

        net.eval()
        with th.no_grad():
            valid_steps = len(valid_loader)
            valid_loss = 0
            preds = []
            for batch_features, batch_labels in valid_loader:
                batch_features = batch_features.to(device)
                batch_labels = batch_labels.to(device)
                logits = net(batch_features)
                batch_loss = F.binary_cross_entropy(logits, batch_labels)
                valid_loss += batch_loss.detach().item()
                preds = np.append(preds, logits.detach().cpu().numpy())
            valid_loss /= valid_steps
            roc_auc = compute_roc(valid_labels, preds)
            print(f'Epoch {epoch}: Loss - {train_loss}, EL Loss: {train_elloss}, Valid loss - {valid_loss}, AUC - {roc_auc}')

        print('EL Loss', train_elloss)
        if valid_loss < best_loss:
            best_loss = valid_loss
            print('Saving model')
            th.save(net.state_dict(), model_file)

        scheduler.step()
Figure 8. Training & Validation metrics

We see that within 25 epochs, we achieve an AUC of over 0.94; further improvements would probably require changes to the learning rate. For now, we move to testing and inference. Note that for inference, we leverage the semantics of the GO ontology axioms to propagate prediction scores to ancestor terms in the ontology and update the scores accordingly.

# Load ontology for inference post-processing
go = Ontology(f'{data_root}/go.obo', with_rels=True)

# Load the test split (mirrors the train/valid loading above)
test_df = pd.read_pickle(f'{data_root}/{ont}/test_data.pkl')
test_data = get_data(test_df, iprs_dict, terms_dict)
test_labels = test_data[1].detach().cpu().numpy()
test_loader = FastTensorDataLoader(*test_data, batch_size=batch_size, shuffle=False)

# Loading best model
print('Loading the best model')
net.load_state_dict(th.load(model_file, map_location=device))
net.eval()

with th.no_grad():
    test_steps = int(math.ceil(len(test_labels) / batch_size))
    test_loss = 0
    preds = []
    for batch_features, batch_labels in tqdm(test_loader, total=len(test_loader)):
        batch_features = batch_features.to(device)
        batch_labels = batch_labels.to(device)
        logits = net(batch_features)
        batch_loss = F.binary_cross_entropy(logits, batch_labels)
        test_loss += batch_loss.detach().cpu().item()
        preds = np.append(preds, logits.detach().cpu().numpy())
    test_loss /= test_steps
    preds = preds.reshape(-1, n_terms)
    roc_auc = compute_roc(test_labels, preds)
    print(f'Test Loss - {test_loss}, AUC - {roc_auc}')


preds = list(preds)

# Propagate scores to ancestor terms using the ontology structure
for i, scores in tqdm(enumerate(preds), total=len(preds)):
    prop_annots = {}
    for go_id, j in terms_dict.items():
        score = scores[j]
        for sup_go in go.get_anchestors(go_id):
            if sup_go in prop_annots:
                prop_annots[sup_go] = max(prop_annots[sup_go], score)
            else:
                prop_annots[sup_go] = score
    for go_id, score in prop_annots.items():
        if go_id in terms_dict:
            scores[terms_dict[go_id]] = score

test_df['preds'] = preds

test_df.to_pickle(out_file)

A snapshot of the test loss and AUC along with the output dataframe is shown below, displaying the predicted annotations and associated confidences.

Figure 9. Test metrics and snapshot of output test dataframe, with predicted annotations and scores.

Conclusions

In this tutorial, we've demonstrated how DeepGOZero combines a model-theoretic approach for learning ontology embeddings with neural networks to improve protein function prediction accuracy. By exploiting the formal axioms in GO, DeepGOZero is able to perform well even on zero-shot predictions. Note that at the time of publication, the authors had already released an updated approach in the form of DeepGO-SE, which relies on a combination of protein language models and neuro-symbolic reasoning to further improve performance on zero-shot prediction cases.

In our next article, we’ll share some thoughts on leveraging DeepGOZero for the newer CAFA5 challenge. Please drop a follow to stay tuned for more relevant content at the interface of ML and the Natural Sciences.

Sources

[1] The Gene Ontology Consortium

[2] Goh, Data Science for Biologists — Databases, NUS

[3] Kulmanov et al., EL Embeddings: Geometric construction of models for the Description Logic EL++

[4] Kulmanov et al., DeepGOZero: improving protein function prediction from sequence and zero-shot learning based on ontology axioms

[5] Peng et al., Description Logic EL++ Embeddings with Intersectional Closure

[6] Baek M. et al. (2021) Accurate prediction of protein structures and interactions using a three-track neural network. Science, 373, 871–876

[7] Jumper J. et al. (2021) Highly accurate protein structure prediction with AlphaFold. Nature, 596, 583–589

[8] Pan J. et al. (2021) Sequence-based prediction of plant protein-protein interactions by combining discrete sine transformation with rotation forest. Evol. Bioinform. Online, 17, 11769343211050067

[9] Sledzieski S. et al. (2021) D-SCRIPT translates genome to phenome with sequence-based, structure-aware, genome-scale predictions of protein-protein interactions. Cell Syst., 12, 969–982

[10] Radivojac P. et al. (2013) A large-scale evaluation of computational protein function prediction. Nat. Methods, 10, 221–227

[11] Kulmanov M., Hoehndorf R. (2019) DeepGOPlus: improved protein function prediction from sequence. Bioinformatics, 36, 422–429

[12] You R. et al. (2021) DeepGraphGO: graph neural network for large-scale, multispecies protein function prediction. Bioinformatics, 37, 262–271

[13] Smith B. et al. (2005) Relations in biomedical ontologies. Genome Biol., 6, R46
