Snowflake in Drug Discovery: Fine-tuned ChemBERTa for Toxicity Prediction, a Deep Dive

This series explores using Snowflake as a foundation for research analytics, storing omics, small molecule, and large molecule data for a variety of tasks. It highlights how Snowflake's combination of data governance and scalability aligns with the paradigm of bringing compute to data. This blog was also part of the Life Sciences Masterclass webinar, which can be viewed on demand at this link.

In the previous blog, we discussed fine-tuning ChemBERTa and its relevance to the pharmaceutical industry. In this blog, we take a deep dive into how fine-tuning is implemented in Snowflake.

Introduction

The Transformer has emerged as a robust architecture for learning self-supervised representations of text. Transformer pre-training followed by task-specific fine-tuning provides substantial gains over previous approaches across many natural language processing (NLP) tasks. ChemBERTa is a transformer-based model that applies NLP techniques to SMILES strings for molecular property prediction. SMILES (Simplified Molecular Input Line Entry System) is a line notation for describing the structure of chemical species.

ChemBERTa employs a bidirectional training context to learn context-aware representations of the PubChem 10M dataset. A variant of the BERT model, ChemBERTa contains 12 attention heads and 6 layers, resulting in 72 distinct attention mechanisms. We fine-tuned ChemBERTa on the ClinTox dataset, a MoleculeNet open dataset, to see how it performs on molecular toxicity prediction. We built a Streamlit app for researchers to make real-time predictions, and we open sourced the code for experimentation.

Snowflake Implementation Deep Dive

System Architecture

Figure 1: System Architecture

Snowpark Container Services (SPCS) provides a scalable, compartmentalized environment for researchers to fine-tune deep learning models. Researchers can download pre-trained models directly from Hugging Face Transformers, PyTorch, or TensorFlow, and fine-tune them with open and proprietary datasets. SPCS enables quick access to GPUs (NVIDIA A10G, NVIDIA A100) and CPUs. In this case, we use a Jupyter Notebook running in SPCS as our IDE and install libraries such as Transformers, RDKit, and DeepChem. Fine-tuned model artifacts are saved in a Snowflake internal stage; the Snowpark ML Model Registry can also be used to store them. A Streamlit app, secured by role-based access control, runs in SPCS so users can make real-time predictions.
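As a minimal sketch of the artifact-persistence step, the snippet below uploads a fine-tuned weights file to an internal stage with the Snowpark API. The stage name, local path, and connection parameters are illustrative assumptions rather than values from the original implementation.

```python
# Minimal sketch: persist fine-tuned artifacts to a Snowflake internal stage.
# Stage name, paths, and connection parameters are assumptions for illustration.
from snowflake.snowpark import Session

connection_parameters = {
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
    "role": "<role>",
    "warehouse": "<warehouse>",
    "database": "<database>",
    "schema": "<schema>",
}
session = Session.builder.configs(connection_parameters).create()

# Upload the serialized model weights produced by fine-tuning.
session.file.put(
    "outputs/chemberta_clintox/pytorch_model.bin",  # local artifact path (assumed)
    "@MODEL_STAGE/chemberta_clintox/",              # hypothetical internal stage
    auto_compress=False,
    overwrite=True,
)
```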

Inference with Pre-trained ChemBERTa

Before fine-tuning, we download pre-trained ChemBERTa from Hugging Face and run inference on a masked SMILES string; ChemBERTa was originally trained to predict masked SMILES tokens. We create an inference pipeline that takes a SMILES string as input, calls the model, and returns predictions.

Figure 2: Inference Pipeline
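A minimal sketch of such a fill-mask pipeline is shown below. The checkpoint name and the masked benzene SMILES are assumptions for illustration; the original notebook may use a different checkpoint and input.

```python
# Minimal sketch: masked-token inference with a pre-trained ChemBERTa checkpoint.
# The checkpoint name and example SMILES are assumptions for illustration.
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

model_name = "seyonec/ChemBERTa-zinc-base-v1"  # assumed public checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

# Benzene with one bond token masked; the model should rank "=" highly.
masked_smiles = "C1=CC=CC<mask>C1"
for prediction in fill_mask(masked_smiles, top_k=5):
    print(prediction["token_str"], round(prediction["score"], 4))
```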

ChemBERTa successfully predicts the masked token as an equals sign, denoting a carbon-carbon double bond, with a confidence score of 0.9694.

Figure 3: Top Five Predictions from ChemBERTa

Visualize the Attention Mechanism in ChemBERTa Using BertViz

BertViz is a tool for visualizing attention in Transformer models, and it supports most models from the Transformers library. Using BertViz, we visualize the attention patterns produced by one or more attention heads in a given ChemBERTa layer. Figure 4 shows a clear path of activated neurons in layer zero that contributed to the final prediction.

Figure 4: Neuron-by-Neuron View
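The neuron-by-neuron view in Figure 4 relies on BertViz's specialized neuron-view model classes; the sketch below shows the simpler head view, which works with a standard Transformers model and renders inside a Jupyter notebook. The checkpoint name and example SMILES are assumptions.

```python
# Minimal sketch: visualize ChemBERTa attention with BertViz's head view.
# Run inside a Jupyter notebook; checkpoint and SMILES are assumptions.
from transformers import AutoModelForMaskedLM, AutoTokenizer
from bertviz import head_view

model_name = "seyonec/ChemBERTa-zinc-base-v1"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name, output_attentions=True)

smiles = "CCO"  # ethanol, kept short so the token grid stays readable
inputs = tokenizer.encode(smiles, return_tensors="pt")
attention = model(inputs).attentions
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

# Renders an interactive view of every head in every layer.
head_view(attention, tokens)
```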

Fine Tune ChemBERTa for Molecular Toxicity Prediction

Pre-trained on PubChem data, ChemBERTa performs quite well at predicting masked SMILES tokens. In learning to recover masked tokens, the model forms a representational topology of chemical space that should generalize to property prediction tasks. The ClinTox dataset compares drugs approved by the FDA with drugs that have failed clinical trials for toxicity reasons. The dataset includes two classification tasks for 1,491 drug compounds with known chemical structures. We fine-tuned ChemBERTa on the ClinTox dataset, repurposing the pre-trained model for molecular toxicity prediction to help accelerate the drug discovery process.

Data Preparation

The ClinTox dataset is publicly available. We download it from a public S3 bucket and use DeepChem to load and split the data into training, validation, and test sets. Figure 5 shows how we leverage DeepChem for data preparation. The Snowpark API can also be used to securely access proprietary datasets stored in Snowflake for fine-tuning.

Figure 5: Download Dataset
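A minimal sketch of this step using DeepChem's MoleculeNet loader is shown below. The featurizer defaults and the task ordering in the label array are assumptions about the loader's behavior, not details taken from the original notebook.

```python
# Minimal sketch: download and split ClinTox with DeepChem's MoleculeNet loader.
import deepchem as dc

# Downloads ClinTox from a public bucket and applies a scaffold split by default.
tasks, datasets, transformers = dc.molnet.load_clintox()
train_dataset, valid_dataset, test_dataset = datasets

# Pull SMILES strings and labels into plain lists for the fine-tuning step.
train_smiles = train_dataset.ids.tolist()
train_labels = train_dataset.y[:, 1].tolist()  # assumed ordering: second task is CT_TOX
print(tasks, len(train_smiles))
```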

Model Training

We fine-tune the model with two different tokenizers to compare the results. One is the Byte-Pair Encoding (BPE) tokenizer commonly used by Transformer models. The other is a SMILES tokenizer, which splits SMILES sequences into syntactically relevant chemical tokens (see the sketch below). We observe superior model performance with the SMILES tokenization strategy.
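To make the difference concrete, the sketch below tokenizes the same molecule both ways, using DeepChem's regex-based BasicSmilesTokenizer for the SMILES side and an assumed BPE checkpoint's tokenizer from Hugging Face for comparison.

```python
# Illustrative sketch: SMILES tokenization vs. BPE tokenization of one molecule.
from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer
from transformers import AutoTokenizer

smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin

# SMILES tokenizer: splits on syntactically meaningful chemical tokens.
print(BasicSmilesTokenizer().tokenize(smiles))

# BPE tokenizer: learned subword merges that need not respect chemistry.
bpe_tokenizer = AutoTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")  # assumed checkpoint
print(bpe_tokenizer.tokenize(smiles))
```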

Fine Tune with SMILES Tokenizer

We use Simple Transformers, a library built on top of Transformers, to load pre-trained ChemBERTa and to set hyperparameters such as the number of epochs and the learning rate for fine-tuning.

Figure 6: Load Pre-trained ChemBERTa with SMILES Tokenizer
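A minimal sketch of this step is shown below. The SMILES-tokenized checkpoint name and the hyperparameter values are assumptions chosen for illustration.

```python
# Minimal sketch: load pre-trained ChemBERTa with Simple Transformers.
from simpletransformers.classification import ClassificationModel, ClassificationArgs

model_args = ClassificationArgs(
    num_train_epochs=10,        # illustrative value
    learning_rate=4e-5,         # illustrative value
    output_dir="outputs/chemberta_clintox",
    overwrite_output_dir=True,
)

model = ClassificationModel(
    "roberta",
    "seyonec/SMILES_tokenized_PubChem_shard00_160k",  # assumed checkpoint
    args=model_args,
    use_cuda=True,  # leverage the SPCS GPU
)
```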

We call the train_model function to kick off the training process. Model artifacts and logs are saved to the specified output directory.

Figure 7: Train ChemBERTa
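A minimal sketch of the training call is shown below, assuming the SMILES lists built in the data-preparation sketch and the two-column DataFrame format that Simple Transformers expects.

```python
# Minimal sketch: fine-tune ChemBERTa on the ClinTox training split.
import pandas as pd

# Assemble the "text"/"labels" DataFrame from the lists prepared earlier (assumed).
train_df = pd.DataFrame({"text": train_smiles, "labels": train_labels})

# Kick off fine-tuning; artifacts and logs land in the model's output_dir.
model.train_model(train_df)
```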

Make Predictions with Fine-tuned ChemBERTa

After only ten epochs, the model is already performing well. We then make predictions on unseen data: a SMILES string that was not part of training.

Figure 8: Make Predictions
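A minimal sketch of the inference call is shown below; the example molecule is an assumption.

```python
# Minimal sketch: predict toxicity for an unseen SMILES string.
predictions, raw_outputs = model.predict(["CC(=O)Oc1ccccc1C(=O)O"])  # example molecule (assumed)
print(predictions)   # predicted labels, e.g. [1] for toxic
print(raw_outputs)   # per-class scores before the argmax
```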

Fine-tuned ChemBERTa correctly predicts the label to be 1.

Figure 9: Predictions

Fine-tuning with the BPE tokenizer follows a similar training process; the model performs better with the SMILES tokenizer.

Results

We used only a subset of the ClinTox dataset and NVIDIA A10G GPUs for training. The fine-tuned model performs well, especially at predicting true positives. Model performance can be improved further with the more powerful NVIDIA A100 GPUs and a larger training set.

Figure 10: Confusion Matrix

Experiment Tracking with Weights & Biases

Weights & Biases helps AI developers track model experiments and evaluate model performance, and we use it for experiment tracking here. After logging in with an API key, metrics such as the ROC-AUC curve, confusion matrix, and GPU utilization are collected automatically for further evaluation.

Figure 11: W&B
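A minimal sketch of enabling tracking is shown below; the project name is a hypothetical placeholder, and the args must be set before the ClassificationModel is created.

```python
# Minimal sketch: experiment tracking with Weights & Biases via Simple Transformers.
import wandb
from simpletransformers.classification import ClassificationArgs

wandb.login(key="<your-wandb-api-key>")  # store the key as a secret in practice

# Setting wandb_project in the model args is enough for training loss,
# evaluation metrics, and system (GPU) stats to be logged per run.
model_args = ClassificationArgs(
    num_train_epochs=10,
    wandb_project="chemberta-clintox",  # hypothetical project name
)
```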

Inference with Streamlit

Streamlit is an open-source Python library that lets machine learning teams easily build web apps where users can make real-time predictions with deployed models. The fine-tuned model takes a SMILES string as input and outputs a predicted label (0 indicates non-toxic, 1 indicates toxic) along with upper-bound and lower-bound predictions.

Figure 12: Streamlit App

It takes only about 12 lines of code to build a web app for ChemBERTa; a sketch is shown below. Streamlit apps can run in the Snowflake UI and in Snowpark Container Services.
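The sketch below follows the assumptions from the earlier snippets (local output directory, 0/1 label meaning) and is not the original app verbatim.

```python
# Minimal sketch of a Streamlit front end for the fine-tuned ChemBERTa model.
import streamlit as st
from simpletransformers.classification import ClassificationModel

st.title("ChemBERTa Molecular Toxicity Prediction")

@st.cache_resource
def load_model():
    # Load the fine-tuned artifacts saved by train_model (path assumed).
    return ClassificationModel("roberta", "outputs/chemberta_clintox", use_cuda=False)

smiles = st.text_input("Enter a SMILES string", "CC(=O)Oc1ccccc1C(=O)O")
if st.button("Predict"):
    predictions, raw_outputs = load_model().predict([smiles])
    label = "toxic (1)" if predictions[0] == 1 else "non-toxic (0)"
    st.write(f"Predicted label: {label}")
    st.write("Raw model outputs:", raw_outputs)
```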

Snowpark Container Services

Machine learning teams can fine-tune open-source models or build models from scratch in Snowpark Container Services. Third-party machine learning tools, libraries, and deep learning frameworks can be easily integrated into an SPCS environment. The primary reasons machine learning engineers use SPCS for deep learning are:

  • Easy and fast access to GPUs. Warm compute pools are provisioned as and when you need them.
  • Zero data movement and foundational data security. SPCS works with data inside Snowflake's governance and security perimeter.
  • A fully managed, scalable containerized environment for deep learning. Snowflake becomes a single environment for data storage, deep learning workloads, model management, and inference.

Future Work

ChemBERTa can be fine-tuned on many MoleculeNet open datasets, covering a wide range of classification and regression tasks in biophysics, physiology, and physical chemistry; aqueous solubility prediction and HIV replication inhibition prediction are two examples. Proprietary datasets stored in Snowflake can also be used for fine-tuning. The code is open sourced on Github.

Acknowledgement

This blog builds on ChemBERTa author Seyone Chithrananda's work; see the arXiv paper.

Additional Reads

Snowflake in Drug Discovery: Fine tuned ChemBERTa for toxicity prediction

Snowflake in Drug Discovery: Leveraging BioNeMo as an external LLM & Snowpark container services for the Protein folding problem

Accelerating Drug Discovery Research with AI/ML Building Blocks and Streamlit
