Snowflake in Drug Discovery: Fine-tuned ChemBERTa for Toxicity Prediction, a Deep Dive
This series explores how Snowflake can serve as a foundation for research analytics, specifically for storing omics, small-molecule, and large-molecule data for a variety of tasks. It highlights how Snowflake, which offers both data governance and scalability, supports the paradigm of bringing compute to the data. This blog was also part of the Life Sciences Masterclass webinar, which can be viewed on demand at this link.
In the previous blog, we discussed fine-tuning ChemBERTa and its relevance to the pharmaceutical industry. In this blog, we take a deep dive into how fine-tuning is implemented in Snowflake.
Introduction
The Transformer has emerged as a robust architecture for learning self-supervised representations of text. Transformer pre-training followed by task-specific fine-tuning provides substantial gains over previous approaches on many natural language processing (NLP) tasks. ChemBERTa is a transformer-based model that applies NLP techniques to SMILES strings for molecular property prediction. SMILES (Simplified Molecular Input Line Entry System) is a line notation for describing the structure of chemical species; for example, ethanol is written CCO.
ChemBERTa uses a bidirectional training context to learn context-aware representations of the PubChem 10M dataset. A variant of the BERT model, ChemBERTa has 12 attention heads in each of its 6 layers, for a total of 72 distinct attention mechanisms (12 × 6). We fine-tuned ChemBERTa on ClinTox, an open MoleculeNet dataset, to see how it performs on molecular toxicity prediction. We also built a Streamlit app so researchers can make real-time predictions, and we open sourced the code for experimentation.
Snowflake Implementation Deep Dive
System Architecture
Snowpark Container Services (SPCS) is a scalable, containerized environment in which researchers can fine-tune deep learning models. Researchers can download pre-trained models directly from Hugging Face Transformers, PyTorch, or TensorFlow and fine-tune them on open and proprietary datasets. SPCS provides quick access to GPUs (NVIDIA A10G, NVIDIA A100) as well as CPUs. In this case, we use a Jupyter Notebook running in SPCS as our IDE and install and import libraries such as Transformers, RDKit, and DeepChem. Fine-tuned model artifacts are saved in a Snowflake internal stage; the Snowpark ML Model Registry can also be used to store them. A Streamlit app, secured by role-based access control, runs in SPCS so that users can make real-time predictions.
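The architecture relies on a GPU compute pool and an internal stage for model artifacts. Below is a minimal sketch of both steps from a notebook, assuming a Snowpark Session object already exists. The pool name, the stage path, and the GPU_NV_S instance family (which we understand to be a single NVIDIA A10G on AWS; check the current Snowflake documentation) are assumptions, not values from this project.

```python
# Sketch: build the DDL for a GPU compute pool and upload fine-tuned artifacts
# to an internal stage. Pool, stage, and directory names are hypothetical.
import os


def compute_pool_ddl(name: str, instance_family: str = "GPU_NV_S",
                     min_nodes: int = 1, max_nodes: int = 1) -> str:
    """Return a CREATE COMPUTE POOL statement for Snowpark Container Services."""
    if min_nodes < 1 or max_nodes < min_nodes:
        raise ValueError("require 1 <= min_nodes <= max_nodes")
    return (
        f"CREATE COMPUTE POOL IF NOT EXISTS {name} "
        f"MIN_NODES = {min_nodes} MAX_NODES = {max_nodes} "
        f"INSTANCE_FAMILY = {instance_family}"
    )


def upload_artifacts(session, local_dir: str, stage_path: str):
    """Upload the files directly under local_dir to a Snowflake internal stage.

    `session` is assumed to be a snowflake.snowpark.Session. The wildcard PUT
    does not recurse into subdirectories, so nested folders need their own call.
    """
    pattern = os.path.join(local_dir, "*")
    return session.file.put(pattern, stage_path, auto_compress=False, overwrite=True)
```

The statement can be executed with session.sql(compute_pool_ddl("CHEMBERTA_POOL")).collect(). Logging the model to the Snowpark ML Model Registry is an alternative to uploading raw files.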
Inference with Pre-trained ChemBERTa
Before fine-tuning, we download the pre-trained ChemBERTa model from Hugging Face and run inference on a masked SMILES string. Because ChemBERTa was originally trained to predict masked SMILES tokens, it can fill in a masked token without any further training. We create an inference pipeline that takes a SMILES string as input, calls the model, and returns predictions.
ChemBERTa correctly predicts the masked token as an equals sign (=), a carbon-carbon double bond, with a confidence score of 0.9694.
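A masked-token inference pipeline of this kind can be sketched with the Hugging Face fill-mask pipeline. The checkpoint name below is an assumption (a public ChemBERTa checkpoint pre-trained on PubChem), and the benzene example is illustrative, not necessarily the string used in this project. Model loading is deferred to a function so the helpers can be used without downloading weights.

```python
# Sketch of masked-token inference with a pre-trained ChemBERTa checkpoint.
# MODEL_ID is an assumption; substitute the checkpoint you actually use.
MODEL_ID = "seyonec/PubChem10M_SMILES_BPE_450k"


def build_fill_mask(model_id: str = MODEL_ID):
    """Return a Hugging Face fill-mask pipeline (downloads weights on first use)."""
    from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id)
    return pipeline("fill-mask", model=model, tokenizer=tokenizer)


def mask_smiles(smiles: str, index: int, mask_token: str = "<mask>") -> str:
    """Replace the character at `index` with the mask token.

    Assumes that character is a single token; a BPE vocabulary may merge it
    with its neighbours, so inspect the tokenization before relying on this.
    """
    if not 0 <= index < len(smiles):
        raise IndexError(f"index {index} outside SMILES of length {len(smiles)}")
    return smiles[:index] + mask_token + smiles[index + 1:]


def best_candidate(predictions):
    """Return (token, score) for the highest-scoring fill-mask prediction.

    The pipeline returns dicts with 'sequence', 'score', 'token', 'token_str'.
    """
    top = max(predictions, key=lambda p: p["score"])
    return top["token_str"], top["score"]


def run_example():
    """Hide the last '=' of benzene (C1=CC=CC=C1) and ask the model to fill it."""
    fill_mask = build_fill_mask()
    masked = mask_smiles("C1=CC=CC=C1", 8, fill_mask.tokenizer.mask_token)
    return best_candidate(fill_mask(masked))
```

Calling run_example() in the notebook returns the top token and its score; the exact score depends on the checkpoint and the input string.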
Visualize the Attention Mechanism in ChemBERTa Using BertViz
BertViz is a tool for visualizing attention in Transformer models and works with most models in the Hugging Face Transformers library. Using BertViz, we visualize the attention patterns produced by one or more attention heads in a given ChemBERTa layer. Figure 4 shows a clear pattern of strong attention weights in layer zero that contributes to the final prediction.
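What BertViz draws for each layer and head is a matrix of attention weights, one row per query token. The sketch below computes those weights with standard scaled dot-product attention in plain Python, then shows a helper that follows BertViz's documented head_view pattern. The checkpoint name is an assumption, and the helper needs transformers, torch, and bertviz installed in the notebook.

```python
# Stdlib sketch of scaled dot-product attention weights (the per-head matrices
# BertViz visualizes), plus a notebook helper for BertViz's head view.
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, keys):
    """Row i holds softmax(q_i . k_j / sqrt(d)) over all keys j."""
    d = len(keys[0])
    scale = 1.0 / math.sqrt(d)
    return [
        softmax([scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys])
        for q in queries
    ]


def show_head_view(smiles, model_id="seyonec/PubChem10M_SMILES_BPE_450k"):
    """Render BertViz's interactive head view in Jupyter (model_id is an assumption)."""
    from bertviz import head_view
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id, output_attentions=True)
    input_ids = tokenizer.encode(smiles, return_tensors="pt")
    outputs = model(input_ids)
    # outputs.attentions: one (batch, heads, seq_len, seq_len) tensor per layer,
    # i.e. 6 layers x 12 heads for ChemBERTa as described above.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    head_view(outputs.attentions, tokens)
```

Each row of attention_weights sums to 1, which is why BertViz can render a head as lines of varying intensity from each token to every other token.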
Fine-Tune ChemBERTa for Molecular Toxicity Prediction
Pre-trained on PubChem data, ChemBERTa performs well at predicting masked SMILES tokens. In learning to recover masked tokens, the model forms a representational topology of chemical space that should generalize to property prediction tasks. The ClinTox dataset compares drugs approved by the FDA with drugs that failed clinical trials for toxicity reasons. It includes two classification tasks for 1,491 drug compounds with known chemical structures. We fine-tuned ChemBERTa on ClinTox, repurposing the pre-trained model for molecular toxicity prediction to accelerate the drug discovery process.
Data Preparation
The ClinTox dataset is publicly available. We download it from a public S3 bucket and use DeepChem to load it and split it into training, validation, and test sets. Figure 5 shows how we use DeepChem for data preparation. The Snowpark API can also be used to securely access proprietary datasets stored in Snowflake for fine-tuning.
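A data-preparation step along these lines can be sketched as follows. It assumes DeepChem's MoleculeNet loader (which, as we understand it, downloads ClinTox from DeepChem's public S3 bucket), a raw-SMILES featurizer, a scaffold split, and CT_TOX (clinical-trial toxicity) as the target column; these choices and the Snowflake table and column names are assumptions. Simple Transformers expects "text" and "labels" columns.

```python
# Sketch of ClinTox preparation. Featurizer, splitter, target task, and the
# Snowflake table/column names are assumptions for illustration.


def load_clintox_splits():
    """Load ClinTox with raw SMILES as features and a scaffold split (needs deepchem)."""
    import deepchem as dc

    tasks, (train, valid, test), _ = dc.molnet.load_clintox(
        featurizer=dc.feat.RawFeaturizer(smiles=True), splitter="scaffold"
    )
    return tasks, train, valid, test


def to_text_label_rows(smiles, labels, tasks, task_name="CT_TOX"):
    """Pair each SMILES with one task column as (text, label) rows."""
    col = list(tasks).index(task_name)
    return [(str(s), int(y[col])) for s, y in zip(smiles, labels)]


def to_dataframe(rows):
    """Build the 'text'/'labels' DataFrame that Simple Transformers expects."""
    import pandas as pd

    return pd.DataFrame(rows, columns=["text", "labels"])


def load_from_snowflake(session, table_name="CLINTOX"):
    """Read proprietary data through Snowpark (hypothetical table and columns)."""
    return session.table(table_name).select("SMILES", "CT_TOX").to_pandas()
```

For example, train_df = to_dataframe(to_text_label_rows(train.ids, train.y, tasks)) turns the DeepChem training split into a fine-tuning DataFrame.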
Model Training
We fine-tune the model with two different tokenizers and compare the results. The first is a Byte-Pair Encoding (BPE) tokenizer, which is commonly used by many Transformer models. The second is a SMILES tokenizer, which splits SMILES sequences into syntactically meaningful chemical tokens. We observe better model performance with the SMILES tokenization strategy.
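The difference between the two strategies is easiest to see on a concrete string. A SMILES tokenizer typically relies on a regular expression that keeps multi-character atoms (Cl, Br), bracket atoms ([nH]), and ring-closure digits intact. The sketch below uses the widely cited pattern from the Molecular Transformer work, which, as we understand it, is also the basis of DeepChem's SMILES tokenizers; it is an illustration, not this project's tokenizer code.

```python
# Regex-based SMILES tokenization sketch.
import re

SMILES_PATTERN = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smiles_tokenize(smiles: str) -> list:
    """Split a SMILES string into chemically meaningful tokens."""
    tokens = SMILES_PATTERN.findall(smiles)
    # Every character must be covered; otherwise the pattern missed a symbol.
    if "".join(tokens) != smiles:
        raise ValueError(f"unrecognized characters in {smiles!r}")
    return tokens
```

For Clc1ccc(Br)cc1 (1-bromo-4-chlorobenzene) this yields Cl, c, 1, c, c, c, (, Br, ), c, c, 1, whereas a character split would break Cl into C and l. A BPE tokenizer instead merges frequent character sequences learned from the corpus, whether or not they correspond to chemical units.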
Fine-Tune with SMILES Tokenizer
We use Simple Transformers, a library built on Hugging Face Transformers, to load the pre-trained ChemBERTa model and set fine-tuning hyperparameters such as the number of epochs and the learning rate.
We call the train_model function to start training. Model artifacts and logs are saved in the specified output directory.
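The loading, hyperparameter, and training steps can be sketched with Simple Transformers' ClassificationModel. The SMILES-tokenized checkpoint name, the learning rate, the batch size, and the output directory are illustrative assumptions; only the ten epochs come from this post. The DataFrames are assumed to have "text" and "labels" columns.

```python
# Sketch of fine-tuning with Simple Transformers. The checkpoint name and all
# hyperparameters except the ten epochs are illustrative assumptions.
SMILES_MODEL_ID = "seyonec/SMILES_tokenized_PubChem_shard00_160k"


def make_training_args(output_dir="outputs/chemberta_clintox", epochs=10,
                       learning_rate=4e-5, batch_size=16):
    """Collect Simple Transformers model args; artifacts and logs go to output_dir."""
    return {
        "num_train_epochs": epochs,
        "learning_rate": learning_rate,
        "train_batch_size": batch_size,
        "output_dir": output_dir,
        "overwrite_output_dir": True,
        "evaluate_during_training": True,  # needs eval_df passed to train_model
    }


def fine_tune(train_df, valid_df, args, model_id=SMILES_MODEL_ID, use_cuda=True):
    """Load pre-trained ChemBERTa as a binary classifier and fine-tune it."""
    from simpletransformers.classification import ClassificationModel

    model = ClassificationModel(
        "roberta", model_id, num_labels=2, args=args, use_cuda=use_cuda
    )
    model.train_model(train_df, eval_df=valid_df)
    return model
```

After training, result, model_outputs, wrong_predictions = model.eval_model(test_df) evaluates the held-out split. Swapping in a BPE checkpoint for model_id is all it takes to run the tokenizer comparison under the same settings.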
Make Predictions with Fine-tuned ChemBERTa
After only ten epochs, the model already performs well. We then make a prediction on unseen data: a single SMILES string.
The fine-tuned ChemBERTa correctly predicts a label of 1 (toxic).
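Simple Transformers' predict method returns both predicted labels and raw model outputs (one row of class logits per input). The sketch below adds a softmax step, which is our addition rather than part of the library's output, to turn the logits into class probabilities alongside the label. The model argument is assumed to be a fine-tuned ClassificationModel.

```python
# Sketch: single-SMILES prediction with class probabilities.
import math


def logits_to_prediction(logits):
    """Return (argmax label, softmax probabilities) for one row of class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    label = max(range(len(probs)), key=probs.__getitem__)
    return label, probs


def predict_smiles(model, smiles):
    """Predict one SMILES string; returns (label, probabilities)."""
    predictions, raw_outputs = model.predict([smiles])
    _, probs = logits_to_prediction(list(raw_outputs[0]))
    return int(predictions[0]), probs
```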
Fine-tuning with the BPE tokenizer follows a similar training process, but the model performs better with the SMILES tokenizer.
Results
We used only a subset of the ClinTox dataset and NVIDIA A10G GPUs for training. The fine-tuned model performs well, especially at predicting true positives. Model performance could be improved further with more powerful NVIDIA A100 GPUs and a larger dataset.
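A claim such as "performs well, especially at predicting true positives" is easiest to check with rates derived from the confusion matrix. The sketch below assumes the result dictionary from Simple Transformers' eval_model exposes tp, tn, fp, and fn counts for binary tasks, as we understand it does; the helper itself is hypothetical.

```python
# Sketch: confusion-matrix rates from an evaluation result dict.


def summarize_confusion(result):
    """Derive recall, precision, and specificity from tp/tn/fp/fn counts."""
    tp, tn, fp, fn = (result[k] for k in ("tp", "tn", "fp", "fn"))

    def ratio(num, den):
        return num / den if den else float("nan")

    return {
        "recall": ratio(tp, tp + fn),        # share of toxic compounds caught
        "precision": ratio(tp, tp + fp),     # share of toxic calls that are right
        "specificity": ratio(tn, tn + fp),   # share of non-toxic compounds cleared
    }
```

On an imbalanced dataset like ClinTox, reporting recall next to specificity shows whether strong true-positive performance comes at the cost of many false alarms.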
Experiment Tracking with Weights & Biases
Weights & Biases helps AI developers track experiments and evaluate model performance. We use it for experiment tracking: after logging in with an API key, model metrics such as the ROC-AUC curve, confusion matrix, and GPU usage are collected automatically for further evaluation.
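With Simple Transformers, W&B logging is switched on through the wandb_project model argument. The sketch below reads the API key from the WANDB_API_KEY environment variable rather than hard-coding it; the project name is hypothetical.

```python
# Sketch: enable Weights & Biases tracking for a Simple Transformers run.
import os


def with_wandb(args, project="chemberta-clintox"):
    """Return a copy of the model args with W&B logging enabled."""
    return {**args, "wandb_project": project}


def wandb_login():
    """Authenticate with the key stored in WANDB_API_KEY (needs wandb installed)."""
    import wandb

    return wandb.login(key=os.environ["WANDB_API_KEY"])
```

Passing with_wandb(args) as the args of ClassificationModel sends training and evaluation metrics to the named project, while W&B collects system metrics such as GPU utilization on its own.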
Inference with Streamlit
Streamlit is an open-source Python library that lets machine learning teams easily build web apps in which users make real-time predictions with deployed models. The fine-tuned model takes a SMILES string as input. The outputs are a predicted label (0 indicates non-toxic, 1 indicates toxic) together with upper-bound and lower-bound predictions.
As shown below, building a web app for ChemBERTa takes only 12 lines of code. Streamlit apps can run in the Snowflake UI as well as in Snowpark Container Services.
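A Streamlit front end of this kind can be sketched as below. This is not the 12-line app referenced here; the model directory, page text, and the choice to show the raw model outputs are assumptions. Streamlit and Simple Transformers are imported inside main so the formatting helper can be reused without them.

```python
# Sketch of a Streamlit app for the fine-tuned classifier. MODEL_DIR and the
# display text are assumptions.
MODEL_DIR = "outputs/chemberta_clintox"
LABELS = {0: "non-toxic", 1: "toxic"}


def describe(label, raw_output):
    """Human-readable summary of one prediction and its raw model outputs."""
    scores = ", ".join(f"{x:.3f}" for x in raw_output)
    return f"Predicted label: {label} ({LABELS[label]}); raw outputs: [{scores}]"


def main():
    import streamlit as st
    from simpletransformers.classification import ClassificationModel

    @st.cache_resource  # load the model once and reuse it across reruns
    def load_model():
        return ClassificationModel("roberta", MODEL_DIR, use_cuda=False)

    st.title("ChemBERTa toxicity prediction")
    smiles = st.text_input("SMILES", "CCO")
    if st.button("Predict") and smiles:
        predictions, raw_outputs = load_model().predict([smiles])
        st.write(describe(int(predictions[0]), list(raw_outputs[0])))
```

Saved as app.py with a final main() call, the app starts with streamlit run app.py.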
Snowpark Container Services
Machine learning teams can fine-tune open-source models and build models from scratch in Snowpark Container Services. Third-party machine learning tools, libraries, and deep learning frameworks can be easily integrated into an SPCS environment. The primary reasons machine learning engineers use SPCS for deep learning are:
- Easy and fast access to GPUs. Warm compute pools are provisioned as and when you need them.
- Zero data movement and foundational data security. SPCS works with data inside Snowflake's governed security perimeter.
- A fully managed, scalable containerized environment for deep learning. Snowflake becomes the single environment for data storage, deep learning tasks, model management, and inference.
Future Work
ChemBERTa can be fine-tuned on many other open MoleculeNet datasets to tackle a wide range of classification and regression tasks in biophysics, physiology, and physical chemistry; aqueous solubility prediction and HIV replication inhibition prediction are two examples. Proprietary datasets stored in Snowflake can also be used for fine-tuning. The code is open sourced on GitHub.
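Moving from ClinTox to one of these tasks mostly changes the data loader and, for regression targets, the model head. The sketch below assumes DeepChem's load_hiv (HIV replication inhibition, classification) and load_delaney (ESOL aqueous solubility, regression) loaders and Simple Transformers' regression mode (num_labels=1 with regression enabled); the mapping itself is illustrative.

```python
# Sketch: per-dataset settings for other MoleculeNet tasks (illustrative mapping).
TASKS = {
    "clintox": {"loader": "load_clintox", "regression": False},
    "hiv": {"loader": "load_hiv", "regression": False},
    "delaney": {"loader": "load_delaney", "regression": True},  # ESOL solubility
}


def model_kwargs(dataset, base_args):
    """Return num_labels and args for ClassificationModel for the given dataset."""
    spec = TASKS[dataset]
    if spec["regression"]:
        return {"num_labels": 1, "args": {**base_args, "regression": True}}
    return {"num_labels": 2, "args": dict(base_args)}


def load_dataset(dataset):
    """Load a MoleculeNet dataset with raw SMILES features (needs deepchem)."""
    import deepchem as dc

    loader = getattr(dc.molnet, TASKS[dataset]["loader"])
    return loader(featurizer=dc.feat.RawFeaturizer(smiles=True), splitter="scaffold")
```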
Acknowledgement
Based on the work of ChemBERTa author Seyone Chithrananda; see the arXiv paper.
Additional Reads
Snowflake in Drug Discovery: Fine tuned ChemBerta for toxicity prediction
Accelerating Drug Discovery Research with AI/ML Building Blocks and Streamlit