A Graph Machine Learning Approach to Synthesizing Raman Spectra

How to construct your own PyG (graph) dataset from scratch, build your own Graph ML model, and train and test the model. By Gordon Downs and Gabe Mudel.

As our course project for Stanford’s CS 224W, we computed theoretical Raman spectra by applying a graph ML model (based on SchNet) to crystal structures of minerals. That is a sentence with a lot of jargon, so let’s unpack it. In particular, let’s define crystal structures, Raman spectroscopy, and graph ML. Then, let’s put it all together in this Google Colab and see the results!

Crystal Structures

Matter is made of atoms, and there is a lot of regularity in the way those atoms arrange themselves. Without getting into too much detail, a common arrangement is a lattice: a unit of atoms repeated periodically through space. For instance, diamond is made of carbon atoms in a grid, like this:

The crystal structure of diamond. Interact with the 3D model yourself using JSmol via AMCSD!

But most crystalline materials have more complicated repeat units, like the mineral fayalite (chemical formula Fe₂SiO₄):

The crystal structure of fayalite. Interact with the 3D model yourself using JSmol via AMCSD!

Here, every different element (Fe, Si, and O) is represented with its own color.

The parallelepiped in each image above is that mineral’s “unit cell.” When we mathematically model minerals, we typically imagine that the unit cell is repeated infinitely in all directions.

In this project, we parsed Crystallographic Information Files (CIFs) so that we could work with them in bulk. The details can be found in our Colab, but essentially, we just use the package PyCIFRW to extract from the CIFs the locations of the atoms in the unit cell. We have to convert the atomic coordinates, which are in direct space¹, to cartesian space for the interatomic distances to be in angstroms (10⁻¹⁰ meters).
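To give a flavor of that step, here is a minimal sketch of the conversion (the file name is a placeholder, and real CIFs need more careful handling; see the Colab for the full version):

```python
import numpy as np
from CifFile import ReadCif  # PyCIFRW

def cif_float(s):
    """CIF values like '4.8195(2)' carry uncertainties in parentheses; drop them."""
    return float(s.split('(')[0])

cif = ReadCif('fayalite.cif')  # placeholder file name
block = cif.first_block()

# Unit cell parameters: lengths in angstroms, angles in degrees.
a, b, c = (cif_float(block[k]) for k in
           ('_cell_length_a', '_cell_length_b', '_cell_length_c'))
alpha, beta, gamma = (np.radians(cif_float(block[k])) for k in
                      ('_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma'))

# Standard direct-space -> cartesian transformation matrix.
v = np.sqrt(1 - np.cos(alpha)**2 - np.cos(beta)**2 - np.cos(gamma)**2
            + 2 * np.cos(alpha) * np.cos(beta) * np.cos(gamma))
M = np.array([
    [a, b * np.cos(gamma), c * np.cos(beta)],
    [0, b * np.sin(gamma), c * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma)],
    [0, 0, c * v / np.sin(gamma)],
])

# Fractional (direct-space) coordinates of each atom, shape (n_atoms, 3).
frac = np.array([[cif_float(x) for x in block['_atom_site_fract_' + axis]]
                 for axis in 'xyz']).T
cart = frac @ M.T  # cartesian coordinates in angstroms
```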

We can infer many properties of minerals through their crystal structures. In this post, we show how crystal structures can be used alongside graph ML to compute theoretical Raman spectra. But first, what are Raman spectra?

Raman Spectroscopy

Raman spectroscopy is a technique that relies on the inelastic scattering of photons — a quantum mechanical effect. Essentially, we shine a monochromatic laser on a sample, and we look at the inelastic scattering that occurs in response.


Since the scattering is inelastic, the photon exchanges energy with the sample, so the wavenumber (1/wavelength) of the scattered light differs from the wavenumber of the monochromatic light source. This difference in wavenumber is called the Raman shift.
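For example, with a 532 nm laser, scattered light at 557 nm corresponds to a Raman shift of roughly 844 cm⁻¹ (these particular numbers are just for illustration):

```python
def raman_shift_cm1(incident_nm: float, scattered_nm: float) -> float:
    """Raman shift in cm^-1, given incident and scattered wavelengths in nm."""
    # 1/wavelength is in nm^-1; multiply by 1e7 to convert to cm^-1.
    return (1.0 / incident_nm - 1.0 / scattered_nm) * 1e7

print(raman_shift_cm1(532.0, 557.0))  # ~843.7 cm^-1
```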

We usually visualize Raman spectra as X, Y pairs with Raman shift (in cm⁻¹) on the X-axis and intensity (in arbitrary units) on the Y-axis, like this:

Raman spectrum of fayalite. Credit: RRUFF

Every species of mineral has a different Raman spectrum, so the Raman spectrum is like a fingerprint for the mineral. We can use Raman spectra not only to identify materials, but also to measure physical properties like temperature.

In this project, we use Raman spectra from the RRUFF database, the gold standard database in the Raman spectroscopy research community. And since different monochromatic light sources produce different Raman spectra, we simplify the problem by only considering Raman spectra collected using 532 nm lasers. Here is a code snippet that gives us the Raman spectra we will use².
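In the Colab we use the apc package for this. As a self-contained stand-in, the sketch below parses RRUFF’s processed-spectrum text files directly, assuming metadata lines start with '##', data rows are comma-separated, and the laser wavelength appears in the file name, as in RRUFF’s bulk downloads:

```python
import glob
import numpy as np

def load_rruff_spectrum(path):
    """Parse a RRUFF processed-spectrum file into (shift, intensity) arrays."""
    xs, ys = [], []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('##'):  # skip metadata lines
                continue
            x, y = line.split(',')
            xs.append(float(x))
            ys.append(float(y))
    return np.array(xs), np.array(ys)

# Keep only spectra collected with a 532 nm laser.
spectra = {path: load_rruff_spectrum(path)
           for path in glob.glob('rruff/*.txt') if '532' in path}
```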

Raman spectra are typically very hard to compute, and most existing software relies on expensive quantum calculations. But, with graph ML, we can frontload that computation (by training our model) and then compute theoretical Raman spectra cheaply! Let’s see how to do that.

Graph Machine Learning

Not every ML problem can be formulated as making a prediction with a grid or sequence of numbers. Graph ML allows us to learn interesting properties about the structure of a much more flexible type of data: graphs.

Graphs are a particularly powerful way to represent data with some form of interconnected structure: social networks, the internet, and (as you may have guessed) atomic structures like crystals are all very nicely represented as graphs! At a high level, our goal is to create graph representations of our crystals, then run a graph neural network (GNN) on these graphs and predict their Raman spectra.

To be able to run a GNN on our data, though, we first need to decide exactly how to convert our crystal structures into graphs. In our case, we choose to represent a crystal by the graph of its asymmetric unit. Without getting into too much technical detail, the asymmetric unit of a crystal is the minimum collection of atoms that gets repeated to form said crystal³. Thus, our final graph representation is the crystal’s asymmetric unit, with nodes representing atoms and edges representing interatomic distances between atoms. We take advantage of the very useful NetworkX library to build our graphs, as sketched in the snippet below⁴:
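(This is a condensed sketch; the 5 Å edge cutoff is an illustrative choice rather than our exact rule.)

```python
import itertools
import networkx as nx
import numpy as np

def build_crystal_graph(atomic_numbers, cart_coords, cutoff=5.0):
    """Graph of an asymmetric unit: nodes are atoms, and edges connect
    pairs of atoms closer than `cutoff` angstroms."""
    g = nx.Graph()
    for i, (z, pos) in enumerate(zip(atomic_numbers, cart_coords)):
        g.add_node(i, z=int(z), pos=tuple(pos))  # one node per atom
    for i, j in itertools.combinations(g.nodes, 2):
        dist = float(np.linalg.norm(np.asarray(cart_coords[i]) - np.asarray(cart_coords[j])))
        if dist < cutoff:
            g.add_edge(i, j, dist=dist)  # interatomic distance as edge feature
    return g
```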

After doing this, we’ll want to convert our crystal structure data into a format compatible with PyTorch Geometric (PyG), then store it for later use. Fortunately, PyG makes this a breeze! The following code pairs up our graphs with their corresponding Raman spectra:
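(In this sketch, graph_spectrum_pairs is a hypothetical list of (graph, spectrum) pairs produced by the previous steps, and the Raman-shift grid bounds are illustrative.)

```python
import numpy as np
import torch
from torch_geometric.data import Data

GRID = np.linspace(0, 1400, 266)  # fixed Raman-shift grid (bounds are illustrative)

def to_pyg_data(g, shift, intensity):
    """Convert a NetworkX crystal graph plus its spectrum into a PyG Data object."""
    z = torch.tensor([g.nodes[n]['z'] for n in g.nodes], dtype=torch.long)
    pos = torch.tensor([g.nodes[n]['pos'] for n in g.nodes], dtype=torch.float)
    src, dst = zip(*g.edges)
    edge_index = torch.tensor([src + dst, dst + src], dtype=torch.long)  # undirected
    # Resample the spectrum onto the fixed grid; this is the regression target.
    y = torch.tensor(np.interp(GRID, shift, intensity), dtype=torch.float).unsqueeze(0)
    return Data(z=z, pos=pos, edge_index=edge_index, y=y)  # y has shape [1, 266]

data_list = [to_pyg_data(g, xs, ys) for g, (xs, ys) in graph_spectrum_pairs]
```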

Finally, we save our dataset to disk:
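(The file name is a placeholder.)

```python
import torch
from torch_geometric.data import InMemoryDataset

data, slices = InMemoryDataset.collate(data_list)
torch.save((data, slices), 'raman_dataset.pt')  # placeholder path
```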

PyG’s collate simply combines our many graphs into one big graph.

After completing this step, our dataset is formatted for easy loading in the future!

Model

Equally important to our success is choosing an appropriate model for the task. The particular graph ML model we use is a variation of SchNet, which is a GNN designed for modeling quantum mechanical interactions between atoms — just like what we see in Raman spectroscopy [1]. SchNet makes use of continuous-filter convolutions. At a high level, these are generalizations of typical convolutional layers that work with arbitrarily spaced inputs of varying quantity (typical convolutions operate on a fixed grid of regularly spaced inputs). With this architecture, we’re able to capture information about crystals through properties such as bond lengths — something clearly very relevant to our use case.

Our full model architecture modifies SchNet slightly: rather than outputting a single scalar value, we output a vector in ℝⁿ. After feeding our graphs through SchNet, we apply an MLP to predict our Raman spectra. For our purposes, we chose a Raman “resolution” of 266 points, with fixed equidistant values along the Raman Shift (X) axis. This means that our model outputs corresponding values along the Intensity (Y) axis for each point along the Raman Shift axis. This brings us to the model architecture shown below:
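Here is a condensed sketch of that architecture using PyG’s stock SchNet implementation: we widen SchNet’s final linear layer so the per-graph readout lives in ℝ²⁶⁶ instead of ℝ, then stack a small MLP head on top. The hidden size and head depth here are illustrative, not necessarily what we trained with:

```python
import torch.nn as nn
from torch_geometric.nn import Sequential
from torch_geometric.nn.models import SchNet

N_POINTS = 266  # resolution of the predicted spectrum
HIDDEN = 128    # illustrative hidden size

# Widen SchNet's scalar output head so the per-graph readout is a
# 266-dimensional vector rather than a single energy value.
schnet = SchNet(hidden_channels=HIDDEN)
schnet.lin2 = nn.Linear(HIDDEN // 2, N_POINTS)

# An MLP head on top of the SchNet readout produces the final intensities.
model = Sequential('z, pos, batch', [
    (schnet, 'z, pos, batch -> h'),  # h: [num_graphs, N_POINTS]
    nn.ReLU(),
    nn.Linear(N_POINTS, N_POINTS),
])
```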

Note that PyG’s Sequential module is slightly different from PyTorch’s: it takes a string defining the input arguments, and each layer can declare which of those arguments it consumes.

Tying it all together with PyG

Finally, with both our dataset and our model, we can use PyTorch and PyTorch Geometric (PyG) to tie up loose ends and train our model! First, we define our optimizer and loss. We choose Adam and mean squared error (MSE) for these, respectively — very standard choices for a regression task like the one we are performing.
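In code (the learning rate is an illustrative value):

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # lr is illustrative
loss_fn = torch.nn.MSELoss()
```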

Next, we leverage PyG’s InMemoryDataset and DataLoader classes to make interacting with our data as painless as possible. Because this is a graph regression task, we are fortunate to be able to treat each graph as an independent data point, and we don’t have to “cut” edges between our nodes⁵. We create train/validation/test splits using 70, 15, and 15 percent of our data, as per standard ML practice.

With this dataset, we’re able to create a DataLoader object that wraps around it, allowing us to easily retrieve a crystal structure and its corresponding Raman spectrum.
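Putting those pieces together (the batch size is an illustrative choice):

```python
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.loader import DataLoader

class RamanDataset(InMemoryDataset):
    """Loads the (data, slices) pair we saved to disk earlier."""
    def __init__(self, path='raman_dataset.pt'):  # placeholder path
        super().__init__()
        self.data, self.slices = torch.load(path)

dataset = RamanDataset().shuffle()

# 70 / 15 / 15 train / validation / test split.
n_train = int(0.70 * len(dataset))
n_val = int(0.15 * len(dataset))
train_set = dataset[:n_train]
val_set = dataset[n_train:n_train + n_val]
test_set = dataset[n_train + n_val:]

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)  # illustrative batch size
val_loader = DataLoader(val_set, batch_size=32)
test_loader = DataLoader(test_set, batch_size=32)
```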

Finally, we can train our model! We write a basic training loop following the skeleton code below:
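In outline (the epoch count is a placeholder):

```python
NUM_EPOCHS = 3000  # placeholder; we trained for several thousand epochs

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        pred = model(batch.z, batch.pos, batch.batch)  # [num_graphs, 266]
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    print(f'epoch {epoch}: train MSE {total_loss / len(train_set):.4f}')
```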

For brevity, we omit validation and model checkpointing in the code above.

It’s worth mentioning that this loop is simple — just like an ordinary PyTorch training loop. We’re able to combine PyG layers with vanilla PyTorch layers and backpropagate through them smoothly. Because our model is built on PyG, we don’t have to do any extra, messy gradient operations for the GNN component of our model!

Results

We add code to perform validation after every epoch and plot the results in Tensorboard.
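A sketch of that validation hook:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()  # logs to ./runs by default

@torch.no_grad()
def evaluate(loader):
    """Mean squared error over a data loader."""
    model.eval()
    total = 0.0
    for batch in loader:
        pred = model(batch.z, batch.pos, batch.batch)
        total += loss_fn(pred, batch.y).item() * batch.num_graphs
    return total / len(loader.dataset)

# At the end of each training epoch:
# writer.add_scalar('loss/val', evaluate(val_loader), epoch)
```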

Loss curves from Tensorboard. X-axis is epoch number, Y-axis is loss

How exciting — our model is learning! After several thousand epochs, we plot some spectra below:

We can clearly see that some predictions are more accurate than others. That being said, the loss curve and the better-performing examples clearly illustrate that the model did indeed learn something meaningful about the relationships between crystal structures and Raman spectra in our dataset.

Future Work

There is great potential for future work in applying graph ML techniques to spectroscopy! Here are some potential avenues for further progress:

  1. Include more features. SchNet as we used it only took in atomic numbers (essentially IDs) and their corresponding positions as features. The dataset we used includes much richer data for each crystal structure⁶, and we surmise that these data can be leveraged to create an even better model.
  2. Go in the opposite direction! Determining a material’s structure from only its Raman spectrum is considered by the mineralogy community to be an intractable problem⁷. But there is potential to achieve this with a generative model such as GraphRNN.

Footnotes

¹ Direct space is a coordinate space used in crystallography. Coordinates in direct space are fractional coordinates, and the unit cell parameters are needed to convert to cartesian space where coordinates have units of angstroms.

² apc is a flexible package for reading mineralogical datasets.

³ More formally, an asymmetric unit is “the smallest fraction of the unit cell that can be rotated and translated using only the symmetry operators allowed by the crystallographic symmetry to generate one unit cell.” This definition is from https://www.sciencedirect.com/topics/chemistry/asymmetric-unit. You can also read more about asymmetric units and their properties by following that hyperlink.

⁴ Note that this snippet (and the snippets that follow) is an incomplete subset of the code used to perform the task at hand — the snippets won’t work out-of-the-box and are only present to aid in understanding. Our full codebase can be found at https://github.com/gordondowns/cs224w-project/.

⁵ For node- and edge-level graph ML tasks, splitting the dataset is more cumbersome because often you only have one very large graph.

⁶ Other potentially useful atom-level features in the dataset include electronegativity, isotropic B-factor, and atomic weight, among others.

⁷ If this Medium post inspires you to solve this problem, please consider citing us and giving us 1% of your Nobel Prize money.

References

[1] Schütt, K. T., Kindermans, P.-J., Sauceda, H. E., Chmiela, S., Tkatchenko, A., & Müller, K.-R. (2017). SchNet: A continuous-filter convolutional neural network for modeling quantum interactions. arXiv [stat.ML]. Retrieved from http://arxiv.org/abs/1706.08566
