Snowflake in drug discovery: Fine-tuned ChemBERTa for toxicity prediction

By Jason Shi and Harini Gopalakrishnan

This series explores how Snowflake can serve as a foundation for research analytics, specifically for storing omics, small molecule, and large molecule data for a variety of tasks. It aims to highlight the transformative role of Snowflake, which offers both data governance and scalability, in aligning with the paradigm of bringing compute to the data. This blog was also part of the Lifesciences Masterclass webinar, which can be viewed on demand at this link. For a deeper dive into the implementation methodology, please read Jason's follow-up blog here.

Image Reference: https://pubs.acs.org/doi/abs/10.1021/acs.jmedchem.9b02147

Introduction to chemical structures and the drug discovery workflow

It’s well known in the pharma and biotech industries that making drugs is expensive, high-risk, and lengthy. It costs anywhere between $300 million and $5 billion to bring a drug to market, and the process takes about ten years. As seen in this article, the estimated cost for pharmaceutical companies to bring a drug to US Food and Drug Administration (FDA) approval was $2.6 billion in 2015, up from $802 million in 2003.

Drug discovery and development is a complex task that involves optimizing vital properties of candidate compounds, such as their physicochemical properties. Medicinal chemists select, design, and prioritize molecular structures based on these factors, including the desired biological activity of the compounds. Important characteristics for drugs include absorption, distribution, metabolism, excretion, and toxicity (ADMET) properties. To accelerate drug screening timelines, medicinal chemists rely on AI/ML models to predict such properties, and these models work on flattened or vectorized representations of small molecules. Compounds are commonly stored as SMILES (Simplified Molecular Input Line Entry System) strings, like the ones shown in Figure 1 below.

Figure 1: SMILES representation of chemical structures
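A toy example makes the phrase "flattened or vectorized representations" concrete. The sketch below uses a hypothetical `smiles_to_vector` helper, written for illustration only, that hashes overlapping character n-grams of a SMILES string into a fixed-length count vector. It treats SMILES as plain text and knows no chemistry, so it is not a substitute for real fingerprints such as those RDKit computes. The vector size and n-gram length are arbitrary assumptions.

```python
import hashlib


def smiles_to_vector(smiles: str, n_bits: int = 64, ngram: int = 3) -> list[int]:
    """Hash overlapping character n-grams of a SMILES string into a count vector.

    Toy featurizer for illustration only: it treats SMILES as plain text and
    ignores chemistry, unlike real molecular fingerprints.
    """
    vector = [0] * n_bits
    # Slide a window of length `ngram` across the string (at least one window).
    for i in range(max(len(smiles) - ngram + 1, 1)):
        fragment = smiles[i:i + ngram]
        # Use a stable hash (Python's built-in hash() is salted per process).
        digest = hashlib.md5(fragment.encode("utf-8")).digest()
        index = int.from_bytes(digest[:4], "big") % n_bits
        vector[index] += 1
    return vector


aspirin = "CC(=O)OC1=CC=CC=C1C(=O)O"
vec = smiles_to_vector(aspirin)
print(len(vec), sum(vec))  # 64 22
```

Every molecule maps to a vector of the same length regardless of its size, which is what lets a downstream classifier consume molecules of different sizes.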

ChemBERTa and the relevance of LLMs in property prediction

Traditionally, predicting the properties of these compounds was done with specialized packages (like RDKit) that operate either on the SMILES strings directly or on an underlying fingerprint representation of the compounds (similar to chemistry cartridges in Oracle). You can read an example of performing structure search with RDKit and Snowpark in this article. However, with the recent advent of LLMs and transformers, there have been attempts to apply them to similar tasks to bring more speed and accuracy to the optimization problem. ChemBERTa, with 12 attention heads and 6 layers (72 distinct attention mechanisms in total), is one such transformer model made available by the open source community. More details on ChemBERTa can be read here. It uses a bidirectional training context to learn context-aware representations of SMILES strings from the PubChem 10M dataset. By pre-training directly on SMILES strings and teaching ChemBERTa to predict masked tokens in each string, the model learns a strong molecular representation. We can then take this model, trained on a structural chemistry task, and apply it to a suite of classification tasks in the MoleculeNet suite.
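The masked-token objective can be illustrated with a minimal sketch. It assumes the common BERT-style recipe of hiding about 15% of tokens. The exact masking scheme used to pre-train ChemBERTa is not described in this post, and the `<mask>` token string is an assumption based on ChemBERTa being, to our understanding, RoBERTa-based. `mask_tokens` is a hypothetical helper, not a library function.

```python
import random

# Assumed mask token: RoBERTa-style tokenizers use "<mask>".
MASK_TOKEN = "<mask>"


def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """Hide a random subset of tokens for masked-language-model pre-training.

    Returns the masked sequence and a parallel list of labels holding the
    original token at each masked position and None elsewhere.
    Simplification: every selected token becomes the mask token; the usual
    BERT recipe also swaps some selections for random or unchanged tokens.
    """
    rng = random.Random(seed)  # seeded so the example is reproducible
    n_to_mask = max(1, round(len(tokens) * mask_prob))
    positions = rng.sample(range(len(tokens)), n_to_mask)
    masked = list(tokens)
    labels = [None] * len(tokens)
    for pos in positions:
        labels[pos] = tokens[pos]  # the model is trained to recover this token
        masked[pos] = MASK_TOKEN
    return masked, labels


# Pre-tokenized SMILES fragment (atom-level tokens).
tokens = ["C", "C", "(", "=", "O", ")", "O", "C", "1",
          "=", "C", "C", "=", "C", "C", "=", "C", "1"]
masked, labels = mask_tokens(tokens)
print(masked)
```

During pre-training, the loss is computed only at positions whose label is not None, so the model learns to infer hidden atoms and bonds from their chemical context.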

Fine-tuning ChemBERTa on Snowflake with Snowpark Container Services

Snowflake recently announced Snowpark Container Services (SPCS), which lets you containerize and bring your own models for fine-tuning and inference. Snowflake users can now fine-tune domain-specific models and set up inference on them without extracting the data and breaking its provenance. In this example, we fine-tune ChemBERTa on ClinTox, a public MoleculeNet dataset, to see how it performs on molecular toxicity prediction. The dataset contains qualitative data on drugs approved by the FDA and on drugs that failed clinical trials for toxicity reasons.

Our ClinTox dataset consists of 1,478 compounds with binary toxicity labels, using SMILES representations to identify the molecules. Figure 2 provides an overview of the conceptual end-to-end workflow in Snowflake, from downloading the ClinTox dataset into an external stage, to fine-tuning with Snowflake's container services, to setting up an inference model for user interaction.

Figure 2: Conceptual schema of the training in Snowflake
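Before staging the file, it can help to inspect it. The minimal stdlib sketch below parses rows in the column layout we understand the public ClinTox CSV to use (`smiles`, `FDA_APPROVED`, `CT_TOX`; verify against your copy) and counts labels per class. The three inline rows are placeholders with arbitrary labels, not real ClinTox records. `load_clintox_rows` and `label_counts` are hypothetical helper names.

```python
import csv
import io

# Placeholder rows in the assumed ClinTox column layout.
# Values are arbitrary and are NOT real ClinTox records.
SAMPLE_CSV = """smiles,FDA_APPROVED,CT_TOX
CCO,1,0
CC(=O)OC1=CC=CC=C1C(=O)O,1,0
C1=CC=CC=C1,0,1
"""


def load_clintox_rows(text):
    """Return (smiles, toxicity_label) pairs, using CT_TOX as the label."""
    rows = []
    for record in csv.DictReader(io.StringIO(text)):
        rows.append((record["smiles"], int(record["CT_TOX"])))
    return rows


def label_counts(rows):
    """Count how many molecules fall into each binary class."""
    counts = {0: 0, 1: 0}
    for _, label in rows:
        counts[label] += 1
    return counts


rows = load_clintox_rows(SAMPLE_CSV)
print(label_counts(rows))  # {0: 2, 1: 1}
```

A quick label count shows whether the classes are balanced, which affects how metrics such as accuracy should be read later.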

Deep diving into the Snowflake implementation pattern

To orchestrate the entire pipeline, we followed these steps:

Step 1: Data preparation

Data preparation was a key step, since we were dealing with SMILES representations of the compounds. Because Snowflake can handle multiple modalities of data, the data remained in Snowflake tables as SMILES strings loaded from the original CSV made available by ClinTox; Snowflake can automatically read CSVs and infer their columns and data types. We then used the DeepChem Python library to prepare the MoleculeNet dataset and split it into training, validation, and test sets. We fine-tuned ChemBERTa with two different tokenizers to compare the results. One is the Byte Pair Encoding (BPE) tokenizer, which is commonly used by many Transformer models. The other is a SMILES tokenizer, which splits SMILES sequences into syntactically relevant chemical tokens. Performance and accuracy were much better with the SMILES tokenizer.
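The difference between the two tokenizers comes down to where a SMILES string is cut. The sketch below is a regex-based SMILES tokenizer adapted from the regular expression popularized by Schwaller et al.'s Molecular Transformer work. As far as we know, DeepChem's `SmilesTokenizer` builds on the same pattern, but this is not DeepChem's class. `smiles_tokenize` is a hypothetical name.

```python
import re

# Bracket atoms, two-letter halogens, aromatic atoms, bonds, branches,
# ring-closure digits (including %NN), and other SMILES punctuation.
SMILES_PATTERN = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def smiles_tokenize(smiles: str) -> list[str]:
    """Split a SMILES string into chemically meaningful tokens."""
    tokens = SMILES_PATTERN.findall(smiles)
    # Sanity check: tokens must reassemble into the input, otherwise some
    # characters were silently skipped by the regex.
    if "".join(tokens) != smiles:
        raise ValueError(f"untokenizable characters in {smiles!r}")
    return tokens


smiles = "CC(Cl)c1cc[nH]c1Br"
chars = list(smiles)  # naive character split breaks "Cl", "Br" and "[nH]"
tokens = smiles_tokenize(smiles)
print(len(chars), len(tokens))  # 18 13
print(tokens)
```

Keeping "Cl", "Br", and bracket atoms such as "[nH]" whole means each token corresponds to an atom, bond, or structural symbol. A learned BPE vocabulary offers no such guarantee.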

Step 2: Fine-tuning

Transfer learning is a machine learning approach that stores knowledge gained while solving one problem and applies it to a different but related problem. Here, the knowledge comes from ChemBERTa's pre-training on masked SMILES strings, and the related problem is toxicity classification. We fed the SMILES strings preloaded into Snowflake tables into the model for a classification task based on the known toxicity labels. The Jupyter notebook used for fine-tuning was also hosted within Snowpark Container Services, ensuring the data never left the Snowflake environment during fine-tuning.

Figure 3: Snapshot of the fine-tuning arguments
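Fine-tuning a model like ChemBERTa usually updates the pretrained weights as well as a new classification head. Its simplest relative, often called linear probing, keeps the pretrained encoder frozen and trains only a classification head on the encoder's embeddings. The sketch below does exactly that, using logistic regression and plain gradient descent. It is not the fine-tuning procedure used in this post. The 2-D "embeddings" are made-up stand-ins for ChemBERTa outputs, and `train_head` and `predict_proba` are hypothetical helpers.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict_proba(weights, bias, x):
    """Probability of the toxic class (label 1) for one embedding."""
    return sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)


def train_head(embeddings, labels, lr=0.5, epochs=200):
    """Train a logistic-regression head on frozen embeddings with batch gradient descent."""
    dim = len(embeddings[0])
    weights, bias = [0.0] * dim, 0.0
    n = len(embeddings)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(embeddings, labels):
            err = predict_proba(weights, bias, x) - y  # d(BCE loss)/d(logit)
            for j in range(dim):
                grad_w[j] += err * x[j]
            grad_b += err
        # Only the head's parameters change; the "encoder" output stays fixed.
        weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / n
    return weights, bias


# Made-up frozen embeddings (2-D for readability) with toxicity labels.
embeddings = [[0.9, 0.1], [0.8, 0.3], [0.2, 0.9], [0.1, 0.7]]
labels = [1, 1, 0, 0]
w, b = train_head(embeddings, labels)
print([round(predict_proba(w, b, x)) for x in embeddings])
```

Full fine-tuning extends the same gradient update to every encoder weight. This is why it typically needs GPUs, which is where the compute pools in Snowpark Container Services come in.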

Step 3: Inferencing and Explainability

The fine-tuned model was saved in a Snowflake internal stage. A Streamlit app running in Snowpark Container Services interacted with the inference model, allowing users to make predictions in real time. For inference, the user provides a SMILES string as input. Predictions include the predicted label (0 indicates the molecule is not toxic, 1 indicates it is toxic) and the upper and lower bounds of the prediction. The final Streamlit output looks like the image in Figure 4.

Figure 4: Streamlit app with Toxicity prediction
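The post does not specify how the app derives the upper and lower bounds. One common approach, shown here purely as an assumption, is to score the same SMILES several times with stochastic inference and summarize the spread. Keeping dropout active at inference time, often called Monte Carlo dropout, is one way to get that randomness. The hypothetical `summarize_prediction` helper below takes those repeated toxicity probabilities and returns a label with empirical 10th/90th percentile bounds. The probabilities in the example are made up.

```python
import statistics


def summarize_prediction(probabilities, threshold=0.5):
    """Turn repeated toxicity probabilities for one SMILES into a label plus bounds.

    Requires at least two probabilities. The label is 1 (toxic) when the mean
    probability reaches the threshold, else 0 (not toxic).
    """
    mean_p = statistics.fmean(probabilities)
    # Nine cut points splitting the data into deciles; keep the 10th and 90th percentiles.
    deciles = statistics.quantiles(probabilities, n=10, method="inclusive")
    return {
        "label": int(mean_p >= threshold),
        "probability": mean_p,
        "lower_bound": deciles[0],
        "upper_bound": deciles[-1],
    }


# Made-up probabilities from five stochastic passes over one molecule.
print(summarize_prediction([0.62, 0.70, 0.66, 0.74, 0.68]))
```

A narrow band between the bounds suggests the model is consistent about a molecule. A wide band flags a prediction worth a closer look.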

Explainability has long been a challenge for deep learning models: it is often hard to see why a model produced a given output, and generative AI models can even give completely different answers to the same question or hallucinate. To address this, we used BertViz to visualize attention in each layer and trace how the model arrives at its output. BertViz is a general-purpose tool for visualizing attention in BERT-style transformer models. In this instance, it helps trace which tokens the model attends to across the six layers, as a check that the toxicity predictions are grounded in meaningful parts of the molecule. The overall architecture is shown in Figure 5.

Figure 5: Overall architecture for the end to end workflow
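BertViz displays the attention weights the model computed for every layer and head. For a single head, these weights are softmax(QKᵀ/√d) over the tokens, where Q and K are the query and key vectors and d is their dimension. The stdlib sketch below computes such a weight matrix from tiny, made-up query/key vectors to show what one BertViz panel represents. It does not extract anything from ChemBERTa, and the helper names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, keys):
    """Scaled dot-product attention weights: softmax(Q K^T / sqrt(d)), row by row."""
    scale = math.sqrt(len(queries[0]))
    return [
        softmax([sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys])
        for q in queries
    ]


tokens = ["C", "C", "(", "Cl", ")"]
# Made-up 2-D query/key vectors, one per token, purely for illustration.
queries = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [1.0, 1.0], [0.0, 0.5]]
keys = [[1.0, 0.0], [0.0, 1.0], [0.2, 0.2], [2.0, 0.0], [0.1, 0.0]]

weights = attention_weights(queries, keys)
for token, row in zip(tokens, weights):
    top = max(range(len(row)), key=row.__getitem__)
    print(f"{token!r} attends most to {tokens[top]!r} ({row[top]:.2f})")
```

Each row sums to 1, so it can be read as how one token distributes its attention across the molecule. BertViz draws exactly these rows as weighted lines between tokens.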

Model performance metrics with Weights & Biases

The model performed well, and we used the Weights & Biases APIs to track the results. The final evaluation results are shown in Figure 6.

Figure 6: Confusion matrix for the fine-tuned ChemBERTa model, visualized in Weights & Biases.
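A confusion matrix like the one in Figure 6 is usually summarized with a few standard metrics. The sketch below derives accuracy, precision, recall, and F1 for the toxic class (label 1) from binary predictions. The label lists are made up for illustration and are not the results behind Figure 6.

```python
def confusion_counts(y_true, y_pred):
    """Return (tp, fp, tn, fn), treating label 1 (toxic) as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return tp, fp, tn, fn


def classification_metrics(y_true, y_pred):
    """Accuracy plus precision, recall and F1 for the toxic class."""
    tp, fp, tn, fn = confusion_counts(y_true, y_pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Made-up labels: 1 = toxic, 0 = not toxic.
y_true = [0, 0, 0, 1, 1, 0, 1, 0]
y_pred = [0, 0, 1, 1, 0, 0, 1, 0]
print(confusion_counts(y_true, y_pred))  # (2, 1, 4, 1)
print(classification_metrics(y_true, y_pred))
```

When the toxic class is rare, precision, recall, and F1 for that class are more informative than accuracy alone, because a model that never predicts "toxic" can still score high accuracy.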

Why Snowflake, and what’s next?

With the ChemBERTa model hosted in secure Snowpark Container Services, the data stays in the high-performance Snowflake store while fine-tuning and inference for tasks such as toxicity prediction run as part of a drug discovery pipeline. The computational models produced from the dataset could become decision-making tools for government agencies in determining which drugs are of the greatest potential concern to human health. These models can also serve as toxicity screening tools in drug discovery pipelines.

Extending this ChemBERTa story further, we can string together multiple downstream and upstream processes that are vital components of a drug discovery pipeline, including structure and substructure searches as demonstrated in this blog. It is also possible to manage large molecule files like PDB, as demonstrated in this article.

From a data science perspective, Snowpark Container Services provides ML engineers with the following advantages:

  • Easy and fast access to GPUs. There are warm compute pools that get provisioned as and when you need them.
  • Zero data movement and foundational data security. SPCS workloads interact with data inside Snowflake's governed security perimeter.
  • Fully managed, scalable containerized environment for deep learning. Snowflake becomes the single environment for both deep learning training and inference.

Finally, we can also fine-tune ChemBERTa on other MoleculeNet tasks that cover a range of dataset sizes and medicinal chemistry applications, from brain penetrability to solubility prediction. Combined with our newly launched Cortex AI functions, this opens up the possibility of hybrid searches that combine vector indexes with regular data.

Note: the code for fine-tuning can be found at the link here.

Additional Reads

Follow-up on ChemBERTa fine-tuning: a technical deep dive

https://medium.com/snowflake/cheminformatics-in-snowflake-using-rdkit-snowpark-to-analyze-molecular-data-9136afb2b10f

https://medium.com/snowflake/snowflake-in-drug-discovery-dc87a25f40a0

https://medium.com/snowflake/accelerating-drug-discovery-research-with-ai-ml-building-blocks-and-streamlit-751e0cf92844
