Published in The Startup
Drug Discovery With Neural Networks

A summary of the Mechanism of Action (MoA) prediction competition on Kaggle, where we used deep learning algorithms to predict the MoA of new drugs.

Figure 1. Overview: (1) Human cells are treated with a drug. (2) Gene expression and cell viability are measured. (3) The data is fed to the neural network. (4) The NN predicts whether the drug has a given MoA or not. [Image by Author]

Discovering a new drug has always been a long process that takes years. With recent advances in AI and the accumulation of research data in biological databases, the drug discovery process is moving faster than ever. Researchers at the Laboratory for Innovation Science at Harvard are working on the Connectivity Map project [1] with the goal of advancing drug development through improved MoA prediction algorithms. The challenge was launched as a Kaggle competition [2] to build machine learning models that predict the MoA of unknown drugs.

1- Dataset:

We start by understanding the competition’s dataset: gene expression and cell viability data as features, and 206 MoAs as targets.
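As a sketch of the dataset’s shape, with synthetic values but the competition’s column-naming conventions (g-0…g-771 for the genes, c-0…c-99 for the cell lines; the target column names here are placeholders for the real MoA strings):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_samples = 5

# Synthetic stand-in for the feature table: 772 gene expression columns
# and 100 cell viability columns, both continuous values.
features = pd.DataFrame(
    rng.normal(size=(n_samples, 772 + 100)),
    columns=[f"g-{i}" for i in range(772)] + [f"c-{i}" for i in range(100)],
)

# Synthetic stand-in for the targets: 206 binary MoA labels per sample.
targets = pd.DataFrame(
    rng.integers(0, 2, size=(n_samples, 206)),
    columns=[f"moa_{i}" for i in range(206)],  # real names are MoA strings
)

print(features.shape, targets.shape)  # (5, 872) (5, 206)
```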

1-1 Gene expression features:

Figure 2. Gene expression assay: The gene expression assay undergoes a series of steps starting from the selection of cell lines, drug treatment, incubation, to mRNA extraction and quantification of a subset of genes. [Image by Author]

The gene expression is measured for 772 genes in this dataset.

Figure 3. Distribution of 5 genes: g-1, g-2, g-3, g-400 and g-771 (out of 772 genes). [Image by Author]

Gene expression was measured with the L1000 assay. You can learn more about this technology on the Connectivity Map webpage and in the research paper “A Next Generation Connectivity Map: L1000 Platform and the First 1,000,000 Profiles,” Cell, 2017 [3].

For the sake of simplicity, only one experimental condition was described above. In fact, each drug was profiled several times, at different dosages (low and high) and different treatment times (24h, 48h and 72h).

1-2 Cell viability features:

Along with the gene expression data of 772 genes, cell viability data was provided for 100 cell lines. The cell viability assay is based on PRISM (Profiling Relative Inhibition Simultaneously in Mixture) [4].

Figure 4. Cell viability assay: cells are pooled and treated with the drug, incubated, and dead cells are counted per cell line. [Image by Author]
Figure 5. Cell viability distribution of 5 cell lines (out of 100). [Image by Author]

The cell viability assessment is based on PRISM. You can learn more about this technology on the Connectivity Map webpage and in the research paper “Discovering the anticancer potential of non-oncology drugs by systematic viability profiling.”

Unlike the gene expression values, which represent the mixture of the 100 cell lines, the cell viability values are per cell line. In other words:

  • Gene-1 values are the average of gene-1 expression over the 100 cell lines, as explained in figure 2, step 4.
  • Cell-1 values are the viability of the cells belonging to cell line 1, as explained in figure 4.
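A toy numerical illustration of that distinction (all numbers hypothetical):

```python
import numpy as np

rng = np.random.default_rng(1)
n_cell_lines = 100

# Hypothetical per-cell-line measurements for one treated sample.
gene1_per_line = rng.normal(loc=0.5, scale=1.0, size=n_cell_lines)
viability_per_line = rng.uniform(0.0, 1.0, size=n_cell_lines)

# A gene feature is ONE number: expression pooled over all 100 lines.
g_1 = gene1_per_line.mean()

# A cell feature is the viability of ONE specific cell line.
c_1 = viability_per_line[0]

print(round(g_1, 3), round(c_1, 3))
```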

So far, we’ve seen the gene expression and cell viability features measured after treatment with the drugs. The only missing piece of the puzzle is the drug’s mechanism of action, which is the target to predict.

1-3 Targets: Drugs MoA

In pharmacology, the term mechanism of action (MoA) refers to the specific biochemical interaction through which a drug substance produces its pharmacological effect.[5]

Let’s make this definition simpler with an example. The drug aspirin reduces pain and inflammation, so for the MoA of aspirin:

  • MoA’s function: Reducing pain and inflammation.
  • MoA’s biochemical function: Irreversible inhibition of the enzyme cyclooxygenase, suppressing the production of prostaglandins and thromboxanes and thereby reducing pain and inflammation.

This is just one of the possible MoAs that aspirin can have: one drug can have more than one mechanism of action, which makes drug MoA prediction a multi-label problem. We were provided 206 MoA targets per drug, labeled 0 (no MoA) or 1 (MoA). The table below displays 4 targets (out of the 206 provided in this dataset).

  • sig_id: the sample containing the mixture of 100 cell lines treated with a drug X (step 1).
  • 5-alpha_reductase_inhibitor, 11-beta-hsd1_inhibitor, acat_inhibitor… are the target mechanisms of action.
Table 1. Example of 4 MoA targets (out of 206), labelled 0 (no MoA) or 1 (MoA).

Let’s take the first row:

  • sig_id ‘id_d00440fe6’ is a mixture of 100 cell lines (see step 1) treated with a drug X (see step 2). This drug doesn’t have the MoA ‘5-alpha_reductase_inhibitor’, so that target is labeled 0, but it has the MoA ‘acat_inhibitor’, so that target is labeled 1.
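The row just described can be reconstructed as a tiny pandas frame (only 3 of the 206 target columns are shown; the 11-beta-hsd1_inhibitor label is assumed to be 0 for illustration):

```python
import pandas as pd

# Toy reconstruction of the first row of Table 1.
targets = pd.DataFrame(
    [[0, 0, 1]],
    index=pd.Index(["id_d00440fe6"], name="sig_id"),
    columns=["5-alpha_reductase_inhibitor", "11-beta-hsd1_inhibitor",
             "acat_inhibitor"],
)

# Multi-label: each sample can activate zero, one, or several targets.
row = targets.loc["id_d00440fe6"]
active = row[row == 1].index.tolist()
print(active)  # ['acat_inhibitor']
```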

Problem statement: 100 cell lines are treated with a drug. Gene expression and cell viability data is collected to understand the biological activity of this drug. The task is to predict the MoA of new drugs based on the gene expression and cell viability features. (See figure 1)

You can learn more about the competition’s data and the features interaction (genes, cells and drugs) in my kaggle notebook: Drugs MoA classification: EDA.

2- Drugs MoA prediction:

We arrive at the most exciting part of this analysis: the prediction of the mechanism of action of new drugs based on their gene expression and cell viability features.

While deep learning dominates computer vision and natural language processing tasks, tree-based algorithms (random forests, decision trees…) and gradient boosting machines (XGBoost, LGBM, CatBoost…) are still the way to go with tabular data. However, that was not the case here: deep learning algorithms outperformed gradient boosting machines. Why? Because we have a multi-label problem with 206 targets to predict. Most shallow machine learning algorithms don’t natively support multi-label tasks; they fit one independent model per target, so they can’t exploit the correlations and co-occurrences among the 206 targets to improve their predictions.

To get a better idea, let’s compare the performance of Ridge, LGBM, XGBoost and 3 deep learning models in this competition.

Figure 6. Model performance in the MoA prediction competition on Kaggle. Shallow machine learning models in red, deep learning models in green. [Image by Author]

These scores are approximate; for a better idea of how the models performed, you can check out the notebooks training them on Kaggle: Ridge, LGBM, XGBoost, ResNet, 4-layer NN and TabNet.
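The competition scored submissions with the mean column-wise log loss, i.e. binary cross-entropy averaged over all 206 targets. A minimal NumPy version (the clipping constant is an assumption, used here only to guard against log(0)):

```python
import numpy as np

def mean_columnwise_logloss(y_true, y_pred, eps=1e-15):
    """Average binary cross-entropy over all targets (columns)."""
    p = np.clip(y_pred, eps, 1 - eps)           # guard against log(0)
    per_cell = -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))
    return per_cell.mean(axis=0).mean()          # mean over rows, then columns

y_true = np.array([[0, 1], [1, 0]])
y_pred = np.array([[0.1, 0.8], [0.7, 0.2]])
print(round(mean_columnwise_logloss(y_true, y_pred), 4))  # 0.2271
```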

The takeaway from the figure above is the gap between shallow ML models and deep learning models. Deep learning models outperformed in this competition because of their ability to extract information from the connections among the 206 targets.

The score difference between shallow machine learning models and neural networks might look small for a metric such as log loss. Neural networks outperformed here because they extract signal from the correlations between targets. However, the correlation between the 206 targets in this dataset was very weak (figure 7): most target pairs had zero correlation and only 13 pairs had a correlation above 0.3, so there wasn’t much signal to extract. The difference in log-loss score would be much larger if the targets were more connected.

Figure 7. Heatmap: Correlation between the 206 targets. [Image by Author]
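The pair-counting behind that observation can be sketched on synthetic labels (the column names and sizes here are placeholders for the real 206 targets):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
# Synthetic binary labels standing in for the 206 MoA target columns.
y = pd.DataFrame(rng.integers(0, 2, size=(500, 20)),
                 columns=[f"moa_{i}" for i in range(20)])

corr = y.corr().values
iu = np.triu_indices_from(corr, k=1)   # upper triangle: each pair counted once
pair_corrs = corr[iu]
strong = int((pair_corrs > 0.3).sum())
print(strong, "pairs out of", pair_corrs.size, "with correlation > 0.3")
```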

In the following sections, I would like to discuss 3 deep learning architectures that performed really well on tabular data with multi-label targets.

2-1 Multi-layer perceptron:

An MLP, or simple feed-forward neural network, is the simplest neural network architecture. A model with 4 dense layers (the first 2 with 2048 neurons and the last two with 1048 neurons), along with dropout, batch normalization layers and the ReLU activation function, performed surprisingly well.

Figure 8. 4 fully connected layers neural network architecture. Each block consists of batch normalization, dropout, a dense layer and a ReLU activation function. [Image by Author]

The score achieved with this model was very competitive, using an Adam optimizer, a reduce-on-plateau learning rate scheduler and a binary cross-entropy loss function (BCEWithLogitsLoss, which includes the sigmoid activation).
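A minimal PyTorch sketch of this architecture (the dropout rate and exact block ordering are assumptions; the layer sizes, loss, optimizer and scheduler follow the text):

```python
import torch
import torch.nn as nn

def block(in_features, out_features, dropout=0.25):
    # One block from Figure 8: BatchNorm -> Dropout -> Dense -> ReLU.
    return nn.Sequential(
        nn.BatchNorm1d(in_features),
        nn.Dropout(dropout),
        nn.Linear(in_features, out_features),
        nn.ReLU(),
    )

class MoAMLP(nn.Module):
    def __init__(self, n_features=872, n_targets=206):
        super().__init__()
        self.net = nn.Sequential(
            block(n_features, 2048),
            block(2048, 2048),
            block(2048, 1048),
            block(1048, 1048),
            nn.Linear(1048, n_targets),  # raw logits; sigmoid lives in the loss
        )

    def forward(self, x):
        return self.net(x)

model = MoAMLP()
criterion = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy in one op
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

x = torch.randn(4, 872)
y = torch.randint(0, 2, (4, 206)).float()
loss = criterion(model(x), y)
print(model(x).shape)  # torch.Size([4, 206])
```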

CODE: 4-layer MLP code in the GitHub repository. [9]

2-2 TabNet:

TabNet was introduced in 2019 by Google Cloud AI in the paper “TabNet: Attentive Interpretable Tabular Learning”. It’s a deep learning model for tabular data that combines the properties of neural networks and tree-based algorithms: [6]

  • It has the power of neural networks to fit and learn complex functions with a high number of parameters.
  • And it has a feature selection mechanism similar to tree-based algorithms, implemented with an attention mechanism.
Figure 9. TabNet architecture. Source: the TabNet paper [6].

The PyTorch implementation of TabNet was done by dreamquark-ai in their tabnet GitHub repository and introduced to Kaggle by optimo in his TabNet Regressor notebook. Tuning and understanding its hyper-parameters can lead to a very powerful model. In fact, TabNet was the strongest single model in the MoA prediction competition, outperforming all the other models.
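The attentive feature selection mentioned above relies on a sparsemax activation rather than softmax, so unimportant features receive a weight of exactly zero. A minimal NumPy sketch of sparsemax (the library ships its own implementation):

```python
import numpy as np

def sparsemax(z):
    """Project logits z onto the probability simplex (Martins & Astudillo, 2016).
    Unlike softmax, low-scoring entries get a weight of EXACTLY zero."""
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]
    k = np.arange(1, z.size + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum       # entries kept in the support
    k_max = k[support][-1]
    tau = (cumsum[k_max - 1] - 1) / k_max     # threshold subtracted from logits
    return np.maximum(z - tau, 0.0)

mask = sparsemax([3.0, 1.0, 0.1])
print(mask)  # [1. 0. 0.] -- only the strongest feature survives
```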

CODE: TabNet code in the GitHub repository. [9]

2-3 DeepInsight CNN:

DeepInsight is a methodology for transforming non-image data into images for convolutional neural network architectures [7], which makes it possible to take advantage of strong pretrained CNN models like EfficientNets. The approach was published in Nature’s Scientific Reports in 2019 and introduced to Kaggle by Mark Peng in his image transformation tutorial and inference notebooks.

Converting tabular data to image data starts by arranging the features in a feature matrix, where each feature’s location depends on its similarity to the other features. We end up with a feature matrix containing several clusters, where each cluster groups similar, highly correlated features together (figure below).
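A rough sketch of that pipeline, using t-SNE to place features on a small pixel grid (the grid size, rounding-based binning and max-on-collision rule are all simplifications of the published method):

```python
import numpy as np
from sklearn.manifold import TSNE

def deepinsight_transform(X, size=8, random_state=0):
    """DeepInsight-style sketch: embed the FEATURES (columns) in 2-D,
    snap each feature to a pixel, then paint every sample into that layout."""
    # One 2-D point per feature; correlated features land close together.
    coords = TSNE(n_components=2, perplexity=5, init="random",
                  random_state=random_state).fit_transform(X.T)
    coords = coords - coords.min(axis=0)
    coords = coords / coords.max(axis=0)                 # scale to [0, 1]
    px = np.clip((coords * (size - 1)).round().astype(int), 0, size - 1)
    images = np.zeros((X.shape[0], size, size))
    for j, (r, c) in enumerate(px):
        # Features colliding on one pixel: keep the max value per sample.
        images[:, r, c] = np.maximum(images[:, r, c], X[:, j])
    return images

X = np.random.default_rng(0).normal(size=(20, 30))  # 20 samples, 30 features
imgs = deepinsight_transform(X)
print(imgs.shape)  # (20, 8, 8)
```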

Figure 10. DeepInsight pipeline. (a) Transforming a feature vector into a feature matrix. (b) Transforming a feature vector into image pixels. Source: the DeepInsight paper [7].

The power of this methodology with gene expression data lies in the arrangement of similar genes into clusters, which makes differences more visible and allows more robust identification of hidden mechanisms than treating elements individually. Feeding these feature-matrix images to CNNs helps catch the small variations in genomic data through the power of the convolution and pooling layers.

To better understand how the DeepInsight transformation [8] works with our data, let’s plot the feature matrices representing the gene expression and cell viability data of 2 samples treated with drugs having 2 different targets (MoA): proteasome inhibitor and DNA inhibitor.

Figure 11. Feature matrices representing the transformed gene expression and cell viability data of 2 samples treated with 2 different drugs. [Image by Author]

The difference between those 2 images is clear: the sample treated with a drug having an active proteasome inhibitor shows a different feature distribution and correlation than the sample treated with a DNA inhibitor. This allows pretrained convolutional neural networks to learn patterns that models fed with tabular data cannot catch.

Training pretrained EfficientNet B3 and B4 models on the DeepInsight-transformed images achieved competitive results. Better yet, it gave a huge boost to the final ensemble with the other neural network models, since it learned new patterns only accessible in the images.

CODE: Image transformation + EfficientNet B4 code in the GitHub repository. [9]


Along with the models discussed in this article, other models performed well in this competition on tabular data with multi-label targets, such as the sequential models LSTM and GRU. Advances in AI are bringing us to the point of solving tabular data problems with ensembles of CNNs and RNNs.


The diagrams, graphs and illustrations were made by the author.


