Application of GNNs for Predicting the Solubility of Molecules (Graph-Level Prediction)

Tejpal Kumawat
6 min read · Mar 12, 2023


In this blog, we will estimate the solubility of molecules directly from their chemical structures.

This is a regression problem: we give the model a SMILES string and predict the molecule's solubility.

We will use PyTorch Geometric for the graph neural network and RDKit to handle the molecular data.

Install PyTorch and PyTorch Geometric with pip.

Install RDKit with pip install rdkit-pypi.

About the Dataset

In the following, we'll use a dataset from PyTorch Geometric's dataset library (here you can find all datasets). The dataset is part of the MoleculeNet collection, which is available here.

“The ESOL dataset contains information on the solubility of 1128 different chemicals in water. The dataset was employed to train models that calculate solubility directly from chemical structures (as encoded in SMILES strings). Due to the fact that solubility is a feature of molecules generally, rather than of specific conformers, these structures do not include 3D coordinates.”

Our task: how well does a solute dissolve in a solvent?

Let’s look at the SMILES representation of a molecule.

[Figure: a SMILES string]

Load Dataset

import rdkit
from torch_geometric.datasets import MoleculeNet

# Load the ESOL dataset
data = MoleculeNet(root=".", name="ESOL")
data
Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv
Processing...
Done!

Investigating the dataset

print("Dataset type: ", type(data))
print("Dataset features: ", data.num_features)
print("Dataset target: ", data.num_classes)
print("Dataset length: ", data.len)
print("Dataset sample: ", data[0])
print("Sample nodes: ", data[0].num_nodes)
print("Sample edges: ", data[0].num_edges)


Result-

Dataset type: <class 'torch_geometric.datasets.molecule_net.MoleculeNet'>
Dataset features: 9
Dataset target: 734
Dataset length: 1128
Dataset sample: Data(x=[32, 9], edge_index=[2, 68], edge_attr=[68, 3], y=[1, 1], smiles='OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O ')
Sample nodes: 32
Sample edges: 68

We can see that there are 9 features per node and that the dataset contains 1128 graphs, one per molecule. Each graph carries a single numerical target (y of shape [1, 1]): its solubility in water. Note that the 734 reported by num_classes is not a class count; for a floating-point target it simply counts the distinct y values, which is not meaningful for a regression task.

# Investigate the features of the nodes of the graph
data[0].x

# Investigating the edges in sparse COO format
# Shape [2, num_edges]
data[0].edge_index.t()

# See the target value of data[0]
data[0].y
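
As a quick sanity check, the shapes line up with the sample printed above (a minimal sketch; the expected values in the comments are taken from that output):

# Quick shape check for the first molecule
sample = data[0]
print(sample.x.shape)           # torch.Size([32, 9])  -> 32 atoms, 9 features each
print(sample.edge_index.shape)  # torch.Size([2, 68])  -> 68 directed edges in COO format
print(sample.y.shape)           # torch.Size([1, 1])   -> one solubility value for the whole graph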

We will make predictions at the graph level. This means we have one y-label for the whole graph, as shown in the figure below.

[Figure: graph-level prediction task, one label per graph]

Converting SMILES to RDKit molecules — Visualizing molecules

We want to convert our SMILES strings into actual molecule objects…

data[0]["smiles"]

# Result: the SMILES string

OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)C(O)C3O
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
molecule = Chem.MolFromSmiles(data[0]["smiles"])
molecule

We have now seen how to convert a SMILES string back into a molecule and visualize it.
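
In a Jupyter notebook the molecule object renders itself automatically; outside a notebook you can draw it explicitly. A minimal sketch using RDKit's drawing module (the file name is just an example):

from rdkit.Chem import Draw

# Render the molecule to a PIL image and save it to disk
img = Draw.MolToImage(molecule, size=(300, 300))
img.save("molecule.png")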

How to use RDKit to extract features from a molecule

RDKit provides all the information we require, such as bonds (edges) and atom properties (type, charge, …).
In our case it's even simpler, because the dataset already contains this information explicitly.
Otherwise, we would use those atom attributes to compute the node features ourselves, as sketched below.
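
A rough sketch of how one could build node and edge information from a SMILES string with RDKit. The chosen atom attributes here are illustrative, not the exact features used by MoleculeNet:

from rdkit import Chem

mol = Chem.MolFromSmiles(data[0]["smiles"])

# Node features: one list of atom attributes per atom (illustrative selection)
node_features = [
    [atom.GetAtomicNum(), atom.GetDegree(), atom.GetFormalCharge(), int(atom.GetIsAromatic())]
    for atom in mol.GetAtoms()
]

# Edges: one (source, target) pair per bond; add both directions for an undirected graph
edges = []
for bond in mol.GetBonds():
    i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    edges += [(i, j), (j, i)]

print(len(node_features), "atoms,", len(edges), "directed edges")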

Implementation of Graph Neural Network

Building a graph neural network is very similar to building a convolutional neural network; we simply stack graph convolution layers.

The GCN simply extends torch.nn.Module. GCNConv expects:

  • in_channels = Size of each input sample.
  • out_channels = Size of each output sample.

We use an initial convolution followed by three more GCN layers, so each node aggregates information from neighbors up to four hops away. In order to do graph-level prediction, we then apply global pooling (here, concatenated mean and max pooling) to combine the individual node embeddings into one vector per graph.
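
A toy illustration of what global pooling does (a minimal sketch; the tensors are made up purely for illustration):

import torch
from torch_geometric.nn import global_mean_pool

h = torch.arange(10.0).view(5, 2)       # embeddings of 5 nodes, 2 dims each
batch = torch.tensor([0, 0, 1, 1, 1])   # node-to-graph assignment: 2 graphs
print(global_mean_pool(h, batch))       # one 2-dim vector per graph -> shape [2, 2]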

We will use PyTorch and PyTorch Geometric for this task.

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

embedding_size = 64

class GCN(torch.nn.Module):
    def __init__(self):
        # Init parent
        super(GCN, self).__init__()
        torch.manual_seed(42)

        # GCN layers
        self.initial_conv = GCNConv(data.num_features, embedding_size)
        self.conv1 = GCNConv(embedding_size, embedding_size)
        self.conv2 = GCNConv(embedding_size, embedding_size)
        self.conv3 = GCNConv(embedding_size, embedding_size)

        # Output layer (mean + max pooling are concatenated, hence embedding_size*2)
        self.out = Linear(embedding_size*2, 1)

    def forward(self, x, edge_index, batch_index):
        # First Conv layer
        hidden = self.initial_conv(x, edge_index)
        hidden = torch.tanh(hidden)

        # Other Conv layers
        hidden = self.conv1(hidden, edge_index)
        hidden = torch.tanh(hidden)
        hidden = self.conv2(hidden, edge_index)
        hidden = torch.tanh(hidden)
        hidden = self.conv3(hidden, edge_index)
        hidden = torch.tanh(hidden)

        # Global Pooling (stack different aggregations)
        hidden = torch.cat([gmp(hidden, batch_index),
                            gap(hidden, batch_index)], dim=1)

        # Apply a final (linear) regression head
        out = self.out(hidden)

        return out, hidden

model = GCN()
print(model)
print("Number of parameters: ", sum(p.numel() for p in model.parameters()))
# OUTPUT
GCN(
  (initial_conv): GCNConv(9, 64)
  (conv1): GCNConv(64, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (out): Linear(in_features=128, out_features=1, bias=True)
)
Number of parameters: 13249
  • Since the molecules can be fairly large, we use an embedding size of 64 rather than something smaller.
  • We capture more of the graph structure as we add more layers.
  • We use a linear layer as the final output layer for the regression problem.
  • Because we only have about 1,000 samples, we keep the number of parameters as small as possible.
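
Before training, we can sanity-check the forward pass on a single molecule (a minimal sketch; the batch vector is all zeros because this "batch" contains only one graph):

# Run one untrained forward pass on the first molecule
sample = data[0]
batch_index = torch.zeros(sample.num_nodes, dtype=torch.long)  # every node belongs to graph 0
out, h = model(sample.x.float(), sample.edge_index, batch_index)
print(out.shape)  # torch.Size([1, 1]) -> one predicted solubility value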

Training the GNN

from torch_geometric.loader import DataLoader
import warnings
warnings.filterwarnings("ignore")

# Mean squared error loss
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)

# Use GPU for training if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Wrap data in data loaders (80/20 train/test split)
data_size = len(data)
NUM_GRAPHS_PER_BATCH = 64
loader = DataLoader(data[:int(data_size * 0.8)],
                    batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
test_loader = DataLoader(data[int(data_size * 0.8):],
                         batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)

def train():
    # Enumerate over the training batches
    for batch in loader:
        # Use GPU if available
        batch = batch.to(device)
        # Reset gradients
        optimizer.zero_grad()
        # Pass the node features and the connectivity info
        pred, embedding = model(batch.x.float(), batch.edge_index, batch.batch)
        # Calculate the loss and gradients
        loss = loss_fn(pred, batch.y)
        loss.backward()
        # Update using the gradients
        optimizer.step()
    return loss, embedding

print("Starting training...")
losses = []
for epoch in range(2000):
    loss, h = train()
    losses.append(loss)
    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Train Loss {loss}")
# OUTPUT
Starting training...
Epoch 0 | Train Loss 3.377596378326416
Epoch 100 | Train Loss 0.9617947340011597
Epoch 200 | Train Loss 1.0771363973617554
Epoch 300 | Train Loss 0.6295697093009949
Epoch 400 | Train Loss 0.37517455220222473
Epoch 500 | Train Loss 0.465716689825058
Epoch 600 | Train Loss 0.5129485726356506
Epoch 700 | Train Loss 0.21677978336811066
Epoch 800 | Train Loss 0.33871856331825256
Epoch 900 | Train Loss 0.3640660345554352
Epoch 1000 | Train Loss 0.20501013100147247
Epoch 1100 | Train Loss 0.18023353815078735
Epoch 1200 | Train Loss 0.2812242805957794
Epoch 1300 | Train Loss 0.18207958340644836
Epoch 1400 | Train Loss 0.1321338415145874
Epoch 1500 | Train Loss 0.18665631115436554
Epoch 1600 | Train Loss 0.1817774772644043
Epoch 1700 | Train Loss 0.09456530958414078
Epoch 1800 | Train Loss 0.23615044355392456
Epoch 1900 | Train Loss 0.11381624639034271
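
The loss reported above is only the loss of the last training batch in each epoch. As a quick check of generalization, we can also compute the mean squared error on the held-out 20% split (a minimal sketch using the test_loader defined above):

# Evaluate mean squared error on the held-out split
model.eval()
with torch.no_grad():
    total_loss, n_batches = 0.0, 0
    for batch in test_loader:
        batch = batch.to(device)
        pred, _ = model(batch.x.float(), batch.edge_index, batch.batch)
        total_loss += loss_fn(pred, batch.y).item()
        n_batches += 1
print("Test MSE:", total_loss / n_batches)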

Visualize Training loss

# Visualize learning (training loss)
import seaborn as sns
losses_float = [float(loss.cpu().detach().numpy()) for loss in losses]
loss_indices = [i for i, l in enumerate(losses_float)]
ax = sns.lineplot(x=loss_indices, y=losses_float)
ax

Prediction of Test Data

import pandas as pd

# Analyze the results for one batch
test_batch = next(iter(test_loader))
with torch.no_grad():
    test_batch = test_batch.to(device)
    pred, embed = model(test_batch.x.float(), test_batch.edge_index, test_batch.batch)
    df = pd.DataFrame()
    df["y_real"] = test_batch.y.tolist()
    df["y_pred"] = pred.tolist()
df["y_real"] = df["y_real"].apply(lambda row: row[0])
df["y_pred"] = df["y_pred"].apply(lambda row: row[0])
df

Let’s visualize y_pred against y_real.

ax = sns.scatterplot(data=df, x="y_real", y="y_pred")
ax.set(xlim=(-7, 2))
ax.set(ylim=(-7, 2))
ax

That’s it; this is our first application of GNNs.
