Using MLflow to Deploy Graph Neural Networks for Monitoring Supply Chain Risk

Ajmal Aziz
Sep 13, 2022

This blog article builds a Lakehouse for supply chain intelligence and monitoring. It demonstrates streaming ingestion, data engineering, training and deploying a Graph Neural Network (GNN) using MLflow on Databricks, and using inference results from the GNN to build an executive-level dashboard. The dashboard (shown below) provides a high-level view of a company’s supply chain risk so that executives and supply chain managers can have a real-time understanding of the overall risk exposure from a production, geographic, and ESG perspective.

This blog productionises a paper published during the author’s time at the University of Cambridge. You can find a copy of the paper here. The notebooks used can be found here.

Completed supply chain risk dashboard on Databricks SQL.

Modern Supply Chains as Complex Networks

We live in an ever more interconnected world, and nowhere is this more evident than in modern supply chains. Driven by globalisation and the global macroeconomic environment, modern supply chains have become intricately linked and woven together. Companies worldwide rely on one another to keep their production lines flowing and to act ethically (e.g., complying with laws such as the Modern Slavery Act). From a modelling perspective, the procurement relationships between firms form an intricate, dynamic, and complex network spanning the globe.

Whilst these networks have enabled labour arbitrage to reduce costs and have made goods more accessible, they have also introduced fragility and hidden vulnerabilities in global value chains. Production shocks to key companies can have far-reaching consequences as they propagate through these networks. The network complexity may also hide production, ESG, and other structural risks for retail and manufacturing firms. This article from McKinsey further explores and characterises the risks and shocks to which companies are vulnerable. Analysis from McKinsey has also estimated that companies can expect to lose around 40% of one year’s profits every decade on average, and that, depending on their severity, these shocks can wipe out an entire year’s worth of earnings in some industries.

These risks are generally hidden because most companies consider only their immediate neighbourhood of nodes within this complex network, i.e. their tier 1 suppliers (captured within the dotted lines in the figure below). Focusing on tier 1 suppliers alone exposes firms to significant risk because they lack transparency over the extended network. There have been instances where companies have suffered economic or reputational damage because their second- or third-tier suppliers allegedly engaged in nefarious activities.

Arrows between suppliers represent buying and selling relationships. Firms typically see only one layer of this network, whilst a web of other suppliers (some of them potentially high risk) remains invisible.

Building Resilient Supply Chains Starts with Data

The first step to gaining supply chain resilience is visibility over this hidden extended network, since visibility can lead to mitigating actions. Unfortunately, it is typically infeasible to discover a company’s entire supply chain network manually: the network changes rapidly, contains hundreds of thousands of connections spanning the globe, and includes competitors and other companies that are under no obligation to reveal this information.

Academic literature proposes web document retrieval to obtain a preliminary representation of a company’s extended supply chain network. This method serves as a great starting point, and we can fill in any remaining gaps by leveraging purpose-built machine learning algorithms that learn from graph-structured data.

As a quick taste of pattern recognition in this domain, in the figure below, an inferential question would be whether or not Company A has a procurement relationship with Company B. Intuitively, from the community structure evident in the graph, a connection between these two nodes should be assigned a low likelihood. In this post, we will train algorithms that perform this type of pattern recognition for us at scale by making use of an exciting area of machine learning research — graph representation learning.

A real Supply Chain Network from the automobile sector. Nodes represent companies, and edges represent known procurement relationships (image retrieved from here).

Just Enough Theory and Context

Note: This section provides just enough theory to be able to understand the implementation. Readers familiar with graph representation learning theory and GNNs can skip this section.

There are broadly three approaches we can take to tackle the relation prediction problem we have outlined. They can be categorized as follows:

  1. Traditional approaches (non-machine learning): manually computing graph statistics such as neighbourhood overlap, the Katz index, etc. (see here for further traditional techniques).
  2. Traditional approaches + machine learning: manually building feature representations of the network (featurising it) so that they can be used as inputs for traditional supervised learning algorithms.
  3. Graph representation learning: using a data-driven approach to learn low-dimensional embedding vectors for nodes, which can be used by downstream learning algorithms.

Although there are merits to each approach, graph representation learning is generally the best-performing, state-of-the-art approach for learning tasks over graphs. Traditional approaches rely on hand-engineered features, which do not adapt through a learning process, are time-consuming to build, and require subject matter expertise (see here for further details). Additionally, when we featurise the network for traditional supervised learning algorithms, we often lose the rich structural information in the network’s connectivity patterns.
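To make the first, non-learning approach concrete, here is a minimal sketch of one neighbourhood-overlap statistic (the Jaccard coefficient) computed in plain Python. The toy edge list and company labels are hypothetical and chosen only to mimic two communities joined by a single link.

```python
from itertools import combinations


def jaccard_overlap(adjacency, u, v):
    """Share of neighbours that u and v have in common (0.0 if both are isolated)."""
    union = adjacency[u] | adjacency[v]
    if not union:
        return 0.0
    return len(adjacency[u] & adjacency[v]) / len(union)


# Hypothetical toy network: two tightly knit communities joined by the D-E edge.
edges = [("A", "C"), ("A", "D"), ("C", "D"), ("D", "E"),
         ("B", "E"), ("B", "F"), ("E", "F")]

# Build an undirected adjacency map from the edge list.
adjacency = {}
for src, dst in edges:
    adjacency.setdefault(src, set()).add(dst)
    adjacency.setdefault(dst, set()).add(src)

# Score every unconnected pair; a higher overlap suggests a more plausible missing link.
candidates = {
    (u, v): jaccard_overlap(adjacency, u, v)
    for u, v in combinations(sorted(adjacency), 2)
    if v not in adjacency[u]
}
```

Companies A and B sit in different communities and share no neighbours, so their score is zero, which matches the intuition from the figure. The statistic is fixed by hand, though, and that limitation is what the learned embeddings are meant to overcome.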

Displaying 3 levels of structural complexity in data types for deep learning.

We will utilise Graph Neural Networks (GNNs) to generate the node embeddings for our graph. GNNs are a general framework for defining deep learning algorithms over graph-structured data. Historically, deep learning algorithms have benefited from strong geometric assumptions, namely that the underlying data sit on a grid, either spatial or temporal. The image above shows the increasing structural complexity of data types for deep learning. Traditional deep learning approaches work well for linear sequence data and grid data, where these structural assumptions hold; they do not hold for more complex graph-structured data. With GNNs, we apply a more generalised form of convolution, described in the next paragraph. For a more detailed exposition on this, see here.

GNNs generate embedding vectors by defining a computation graph for each node based on its local connectivity structure (see image below). Intuitively, one can think of this step as aggregating both feature and structural information from each node’s local neighbourhood. This is achieved by learning local functions, shared across all nodes in the graph, that aggregate feature information and produce embedding vectors. This is more formally referred to as neural message passing, where the messages are node features.

When generating embeddings, the computation graph is defined based on the local connectivity structure of a given node (image source).
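The data flow of message passing can be sketched in a few lines of plain Python. This is an illustration only: a trained GraphSAGE layer applies learned weight matrices and a non-linearity, whereas the hypothetical `mean_aggregate` below simply averages vectors so that the neighbourhood aggregation is easy to follow.

```python
def mean_aggregate(features, neighbours):
    """One round of message passing with a mean aggregator and no learned weights."""
    updated = {}
    for node, own in features.items():
        messages = [features[n] for n in neighbours.get(node, [])]
        if messages:
            # Element-wise mean of the neighbours' feature vectors.
            neighbour_mean = [sum(col) / len(messages) for col in zip(*messages)]
        else:
            neighbour_mean = [0.0] * len(own)
        # A trained layer would combine these with learned weights and a
        # non-linearity; here the two vectors are just averaged.
        updated[node] = [(a + b) / 2 for a, b in zip(own, neighbour_mean)]
    return updated


# Hypothetical three-company graph: one buyer with two suppliers.
features = {"supplier_1": [1.0, 0.0], "supplier_2": [0.0, 1.0], "buyer": [0.0, 0.0]}
neighbours = {
    "buyer": ["supplier_1", "supplier_2"],
    "supplier_1": ["buyer"],
    "supplier_2": ["buyer"],
}

one_hop = mean_aggregate(features, neighbours)   # each node sees direct neighbours
two_hop = mean_aggregate(one_hop, neighbours)    # information from two hops away
```

Stacking two rounds means `supplier_1` ends up with a non-zero second component that originated at `supplier_2`, even though the two are not directly connected.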

Once embeddings have been obtained for nodes in the graph we can use them to solve problems that can be specified over graphs. For example, tasks can be split into regression and classification over nodes, edges, or the entire graph depending on the problem. Our reasoning task is to learn from structural information in the supply chain network to locate links (or procurement relationships) that are not captured, also referred to as link prediction.

Lastly, it was mentioned earlier that GNNs are a framework for defining deep learning algorithms over graph-structured data. For this blog, we will utilise a specific GNN architecture called GraphSAGE. This algorithm does not require all nodes to be present during training, is able to generalise to new nodes efficiently, and can scale to billions of nodes. Earlier methods in the literature were transductive, meaning that the algorithms learned an embedding for each node directly. This was useful for static graphs, but the algorithms had to be re-run after graph updates such as new nodes. Unlike those methods, GraphSAGE is an inductive framework which learns how to aggregate information from neighbouring nodes; i.e., it learns functions for generating embeddings, rather than learning the embeddings themselves. GraphSAGE therefore lets us integrate new supply chain relationships retrieved from upstream processes without triggering costly retraining routines. For more information on GraphSAGE, see the original paper on inductive representation learning for large graphs.

Reference Solution Architecture

We’re going to assume that a company has set up a web document retrieval method to obtain preliminary buying and selling relationships between firms. We will incrementally ingest this data from cloud storage using Auto Loader. This emulates how a retrieval system would store this raw collected data in cloud storage. As the files are incrementally ingested into the Delta Lakehouse, we also emulate having to clean this collected data through a medallion architecture prior to any BI or machine learning.

We then incorporate static data from our supply and finance teams, which hold structured information about country risk profiles and company risk scores (in CSV). These tables inform the dashboarding layer. The GNN serves as a refinement step when going from Silver to Gold tables. At inference time, low-confidence links in the Silver tables, on which the GNN was not trained, are fed to the GNN and its probability scores are compared with the retrieval system’s confidence scores. If the GNN outputs a higher probability than the retrieval system, these links are re-introduced in the Gold tables. Finally, the Gold tables feed a Databricks SQL dashboard that provides a holistic view of the company’s supply chain network and helps identify associated risks.

Overall architecture using MLflow for GNN training and deployment. The basis is a medallion data engineering architecture with final gold tables served to Databricks SQL.

Data Engineering

There are two data engineering steps required for refining our Bronze tables to Silver:

  1. Our web document retrieval data contains a confidence score, and we do not want to train our GNN on low-probability links, so we filter out links below a confidence threshold. We choose a threshold of 0.55.
  2. The raw company names carry suffixes (Ltd., LLC, etc.) that denote different legal entities but point to the same physical company. We need to consolidate these names so that we do not duplicate nodes in the network. We leverage cleanco, an open-source library, for this.

To complete both of these steps, we register a Spark UDF and apply it to the Auto Loader stream as follows:
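A minimal sketch of the logic behind both steps, written in plain Python rather than as the Spark UDF itself. The column names (`src`, `dst`, `confidence`) and the short regular expression standing in for cleanco are assumptions made for illustration; cleanco handles far more legal forms.

```python
import re

CONFIDENCE_THRESHOLD = 0.55  # threshold quoted in the text

# Hypothetical, deliberately short suffix list used in place of cleanco.
LEGAL_SUFFIXES = re.compile(r"[\s,]+(ltd|llc|inc|plc|gmbh|corp)\.?$", re.IGNORECASE)


def clean_company_name(name):
    """Strip trailing legal-entity suffixes so that name variants collapse to one node."""
    cleaned = name.strip()
    previous = None
    while cleaned != previous:  # repeat in case several suffixes are stacked
        previous = cleaned
        cleaned = LEGAL_SUFFIXES.sub("", cleaned).strip()
    return cleaned


def bronze_to_silver(rows):
    """Drop links below the threshold and normalise both company names."""
    silver = []
    for row in rows:
        if row["confidence"] < CONFIDENCE_THRESHOLD:
            continue
        silver.append({
            "src": clean_company_name(row["src"]),
            "dst": clean_company_name(row["dst"]),
            "confidence": row["confidence"],
        })
    return silver


# Hypothetical bronze records as they might arrive from the retrieval system.
bronze = [
    {"src": "Acme Ltd.", "dst": "Bolt GmbH", "confidence": 0.91},
    {"src": "Acme, LLC", "dst": "Cog Inc", "confidence": 0.72},
    {"src": "Acme Ltd.", "dst": "Dyna Corp.", "confidence": 0.30},
]
silver = bronze_to_silver(bronze)
```

In a streaming pipeline, `clean_company_name` is the kind of function that would be wrapped as a UDF, with the threshold applied as a filter on the stream.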

Now the silver table is primed for machine learning. We assume that the static tables from the risk and finance departments are ready to be used as they are.

Implementation of the Learning Process: The Neural Network Architecture

The overall GNN architecture is shown below. The graph learning algorithms in this blog make use of topological information (connectivity structure) as well as node features (orange rectangles in the figure below). We randomly initialise the companies’ node feature vectors. These vectors are exchanged between neighbouring nodes during message passing and form the basis for the embeddings once the GNN model has been trained.

Our GNN model will consist of two GraphSAGE layers to generate node embeddings. The embeddings are then fed into a separate fully connected neural network that will take the embeddings for source and destination nodes as inputs and provide a prediction for the likelihood of a link (i.e. binary classification). More formally, the neural network acts as:

The function of the final fully-connected layer where the inputs are latents associated with nodes under consideration.

where the inputs are latent representations of the nodes under consideration and the output is the likelihood that a link exists between the node pair. All of the network weights, including those of the GNN and of the fully connected MLP, are trained using a single loss function. Two losses are considered, binary cross-entropy loss and margin loss; whichever yields the best results on validation data is chosen. Lastly, the Deep Graph Library (DGL) with a PyTorch backend is used for training.
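One common form for such a predictor is written out below as an illustration, not as the exact formula from the paper. It assumes the source and destination latents $\mathbf{z}_u$ and $\mathbf{z}_v$ are concatenated and passed through a two-layer MLP with a sigmoid output; the symbols $\mathbf{W}_1$, $\mathbf{b}_1$, $\mathbf{w}_2$, and $b_2$ are introduced here purely for notation.

```latex
\hat{y}_{uv} = \sigma\!\left( \mathbf{w}_2^{\top}\,
    \mathrm{ReLU}\!\left( \mathbf{W}_1 \left[ \mathbf{z}_u \,\Vert\, \mathbf{z}_v \right] + \mathbf{b}_1 \right)
    + b_2 \right)
```

Here $\Vert$ denotes concatenation and $\hat{y}_{uv} \in (0, 1)$ is read as the likelihood of a procurement link from $u$ to $v$.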

Forward pass architecture used in this example. Orange rectangles represent node features, and yellow rectangles represent latents.

The code for creating the model definition is shown below. The class extends the PyTorch neural network module with a few key modifications. Namely, a forward pass through the model class involves a graph convolution with a custom defined class called GraphSAGE. In the forward pass, the embeddings are fed through a simple MLPPredictor where the node embeddings are stacked and the output is a single score. The three key class definitions are shown below beginning with the GraphSAGE class, the MLPPredictor class, and the final Model class that cascades the two other modules.

Delta tables for relational data

Before diving into the specifics of training this model, let’s consider the data. The full dataset consists of company relationships stored in Delta tables. This brings performance benefits as well as features such as time travel, which lets us keep track of the version history of our collected supply chain graph. Since we are using DGL for training our graph neural network models, we define a function that converts a Delta table to a DGL graph as follows:
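The core of such a conversion is mapping company names to contiguous integer node ids, since graph libraries index nodes by integer. The sketch below shows that step in plain Python on rows assumed to have already been collected from the Delta table; the `src` and `dst` column names are hypothetical.

```python
def relations_to_edge_lists(rows):
    """Map company names to contiguous integer ids and return (src_ids, dst_ids, name_to_id)."""
    name_to_id = {}
    src_ids, dst_ids = [], []
    for row in rows:
        for name in (row["src"], row["dst"]):
            # Assign the next free id the first time a company is seen.
            name_to_id.setdefault(name, len(name_to_id))
        src_ids.append(name_to_id[row["src"]])
        dst_ids.append(name_to_id[row["dst"]])
    return src_ids, dst_ids, name_to_id


# Hypothetical rows, e.g. the silver relations pulled into driver memory.
rows = [
    {"src": "Acme", "dst": "Bolt"},
    {"src": "Acme", "dst": "Cog"},
    {"src": "Bolt", "dst": "Cog"},
]
src_ids, dst_ids, name_to_id = relations_to_edge_lists(rows)

# With DGL installed, the graph itself would be built along the lines of:
#   graph = dgl.graph((src_ids, dst_ids), num_nodes=len(name_to_id))
```

Keeping `name_to_id` around matters at inference time, because predictions come back indexed by node id and need to be translated back to company names.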

Training Scheme for Link Prediction

We will train our GNN in the mini-batch setting since this format scales well as the number of nodes in the graph grows. This method of training is also referred to as stochastic training of GNNs. We will leverage a graph data loader class from the Deep Graph Library (DGL) to facilitate the network training process. DGL provides data loaders specific to GNN training, and we can interact with them in a similar fashion to PyTorch data loaders. Graph data loaders in graph representation learning are task-dependent. For link prediction, the GNN is trained on edge batches. If the task were node-focused (e.g. node classification or regression), the graph would be partitioned based on nodes and we would use a node data loader.

The edge data loaders offer a generator interface over sampled edge batches from our training, validation, and testing graphs. Edge batches contain positive edges, which are edges that are observed, and negative edges, which are node pairs that do not exist in the original graph. The learning algorithm is then assessed on how accurately it is able to distinguish between positive and negative edges.

The graph is partitioned according to edges. The graph partitions are then split into positive (real) and negative (non-existent) edges. Negative edges are shown in red.

Let’s demonstrate below how to create these edge data loaders. The following function takes a dictionary of graph partitions as input and returns edge data loaders with a fixed negative sampling scheme. The negative sampling scheme is an important design choice, but for the sake of demonstration we have chosen a relatively simple method: negative edges are drawn from a uniform distribution.
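To illustrate what uniform negative sampling means, here is a minimal pure-Python sketch. It is not the DGL data loader; the function name and arguments are hypothetical. For each observed edge it keeps the source node and draws destination nodes uniformly at random, rejecting self-loops and pairs that are real edges. DGL ships a built-in uniform negative sampler that works on the same principle of corrupting destination nodes.

```python
import random


def sample_negative_edges(positive_edges, num_nodes, negatives_per_edge, seed=0):
    """Draw `negatives_per_edge` unobserved (src, dst) pairs for every observed edge."""
    rng = random.Random(seed)
    observed = set(positive_edges)
    negatives = []
    for src, _ in positive_edges:
        drawn = 0
        # Assumes each source has enough non-neighbours; otherwise this loop would not end.
        while drawn < negatives_per_edge:
            dst = rng.randrange(num_nodes)  # uniform over all node ids
            if dst == src or (src, dst) in observed:
                continue  # reject self-loops and real edges
            negatives.append((src, dst))
            drawn += 1
    return negatives


# Hypothetical directed chain over six nodes.
positive_edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
negative_edges = sample_negative_edges(positive_edges, num_nodes=6, negatives_per_edge=2)
```

The number of negatives per positive edge is one of the knobs that can later be tuned alongside the model hyperparameters.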

The trainable network weights (including the GNN layers) are trained via stochastic gradient descent. The loss is calculated by assessing the model’s capability to distinguish between positive graphs and negative graphs given their node features and connectivity structure. These sub-graphs are generated from the edge data loaders. In the code snippet below, a perfect model would return a tensor of 1s for pos_score and 0s for neg_score.
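The two candidate losses mentioned earlier can be sketched in plain Python as follows. This assumes, in line with the description above, that `pos_score` and `neg_score` hold probabilities in [0, 1], and that the margin loss pairs each positive edge with one negative edge. The function names are illustrative and not taken from the notebooks.

```python
import math


def binary_cross_entropy(pos_score, neg_score, eps=1e-7):
    """Mean BCE where positive edges carry label 1 and negative edges label 0."""
    def clip(p):
        # Keep probabilities away from 0 and 1 so the logarithm stays finite.
        return min(max(p, eps), 1.0 - eps)

    losses = [-math.log(clip(p)) for p in pos_score]
    losses += [-math.log(1.0 - clip(p)) for p in neg_score]
    return sum(losses) / len(losses)


def margin_loss(pos_score, neg_score, margin=1.0):
    """Mean hinge penalty when a negative scores within `margin` of its paired positive."""
    penalties = [max(0.0, margin - p + n) for p, n in zip(pos_score, neg_score)]
    return sum(penalties) / len(penalties)


# A perfect model: 1s for observed edges, 0s for sampled non-edges.
perfect_bce = binary_cross_entropy([1.0, 1.0], [0.0, 0.0])
perfect_margin = margin_loss([1.0, 1.0], [0.0, 0.0])

# A decent but imperfect model.
good_bce = binary_cross_entropy([0.9, 0.8], [0.1, 0.3])
good_margin = margin_loss([0.9, 0.8], [0.1, 0.3])
```

Both losses are close to zero for the perfect scores and grow as positive and negative edges become harder to tell apart. The margin loss only cares about the gap between the paired scores, not their absolute values.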

Distributing hyperparameter tuning using HyperOpt

Although we can leverage backpropagation for the network weights, there are still a number of design choices that are key to the success of this learning problem. Some variables include the size of the node features, the number of negative samples, and even the aggregator type in the GraphSAGE layers which refers to a permutation invariant function for aggregating node features.

To search through the design space we leverage HyperOpt, a widely adopted open-source framework for distributed hyperparameter tuning. HyperOpt uses Bayesian optimisation to search the design space, treating the objective as a black-box function. We define the objective function and the design space, and distribute the search across a Spark cluster with SparkTrials. The article Scaling Hyperopt to Tune Machine Learning Models in Python gives an excellent deep dive into how this works. In our case, we define a function called train_and_evaluate_gnn. This function trains our model for a given configuration of parameters and returns the negative average ROC-AUC across a fixed set of validation edge batches (or subgraphs). This returned value is the loss that HyperOpt tries to minimise.
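The quantity handed back to the optimiser can be sketched without any dependencies. The code below computes ROC-AUC directly from its pairwise definition and negates the batch average; it stands in for the evaluation part of train_and_evaluate_gnn only. The helper names and the toy scores are hypothetical, and model training is left out entirely.

```python
def roc_auc(pos_scores, neg_scores):
    """Probability that a random positive edge outranks a random negative edge (ties count half)."""
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


def validation_loss(batches):
    """Negative mean ROC-AUC over validation batches, so that lower is better."""
    aucs = [roc_auc(pos, neg) for pos, neg in batches]
    return -sum(aucs) / len(aucs)


# Hypothetical (pos_score, neg_score) pairs for two validation edge batches.
validation_batches = [
    ([0.9, 0.7, 0.6], [0.2, 0.65]),
    ([0.8, 0.4], [0.1, 0.3]),
]
loss = validation_loss(validation_batches)
```

The sign flip is needed because the optimiser minimises its objective, while a higher AUC is better.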

Deploying our Model using MLflow

Once the search is complete and an acceptable set of parameters has been determined, we can use MLflow to log, register, and deploy our GNN model. This demonstrates how flexible MLflow is as a modelling framework. MLflow provides a set of pre-defined model flavours for popular (and some obscure) frameworks for packaging models for serving. As of now, there is no out-of-the-box flavour for GNNs, but we can leverage the mlflow.pyfunc module to create a custom GNN model (GNNWrapper in our case) with its own logic. In this class, we have custom logic for converting tabular data to a graph, initialising node features, and finally generating new node embeddings using the trained GNN model.

We leverage MLflow for tracking our model runs and the model registry when using the model for inference.
The t-SNE plot is logged as a model artefact.

Once an MLflow model flavour has been defined, there are a number of key benefits. In the code snippet below, we demonstrate logging metrics, parameters, the GNN model, and a t-SNE plot as an artefact of the final model run. The t-SNE plot is a great diagnostic tool for judging whether the learned embeddings are useful. The gif above shows the t-SNE plot we log for this model. In this plot, we can see two distinct clusters of embeddings. We can easily imagine the neural network’s decision boundary cutting between these clusters, which implies a good level of separation. In this specific run, the model has a training AUC of 0.84, a validation AUC of 0.93, and a testing AUC of 0.95, which shows good generalisation. We are also happy with the t-SNE plot and will therefore deploy the model to production.

Now that we have tuned our GNN and graph design space, we can deploy our GNN as a UDF for distributed inference using MLflow. We will query the GNN as to whether a low-confidence link should indeed be added to the overall network. If the GNN provides a high likelihood for the link, we add the link to the silver table, thereby creating a gold table. We can achieve this easily using mlflow.pyfunc.spark_udf. The code below shows us getting predictions in a batch setting from our GNN model.

Now in a few lines, we can obtain our gold relations by directly comparing the original probability score in the silver table against the prediction provided by the GNN model.
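The comparison itself is simple enough to sketch in plain Python. The record layout and the `gnn_probability` lookup are assumptions made for illustration; in the pipeline the same rule would be expressed over Spark DataFrames, with the GNN scores coming from the registered model.

```python
def refine_to_gold(silver_links, low_confidence_links, gnn_probability):
    """Keep all silver links and re-introduce low-confidence links the GNN rates higher."""
    gold = list(silver_links)
    for link in low_confidence_links:
        pair = (link["src"], link["dst"])
        # Re-introduce the link only if the GNN is more confident than the retrieval system.
        if gnn_probability[pair] > link["confidence"]:
            gold.append({**link, "gnn_probability": gnn_probability[pair]})
    return gold


# Hypothetical inputs: one trusted link and two links that fell below the threshold.
silver_links = [{"src": "Acme", "dst": "Bolt", "confidence": 0.91}]
low_confidence_links = [
    {"src": "Acme", "dst": "Dyna", "confidence": 0.30},
    {"src": "Cog", "dst": "Dyna", "confidence": 0.50},
]
gnn_probability = {("Acme", "Dyna"): 0.82, ("Cog", "Dyna"): 0.12}

gold = refine_to_gold(silver_links, low_confidence_links, gnn_probability)
```

In this toy run only the Acme to Dyna link is promoted, because the GNN assigns it a higher probability than the retrieval system did.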

Final Supply Chain Dashboard

This refined gold table of relationships, along with the tables originally retrieved from finance, can now be used to build out a supply chain dashboard. This dashboard would be immensely useful to a company’s C-suite, operational, finance, and risk teams, amongst others. Business units can use it to view the geographic footprint of the company’s entire supply chain and to assess how ESG, financial, and production risks are tracking, using the most up-to-date views refined with learning algorithms. This analysis can deliver game-changing outcomes by allowing companies to go from being reactive to predictive about their supply chains, and it demonstrates how they can gain a competitive advantage by increasing their Data and AI maturity through the Lakehouse on Databricks.

An example supply chain surveillance dashboard based on risk and finance tables alongside gold tables refined by our GNN model.

What’s Next

Supply chain surveillance, resilience, and visibility are imperatives for modern organisations. We have seen how supply shocks can impact consumers and organisations alike. We have shown in this blog how we can take a predictive and data-driven approach by modelling the underlying system as a network and then applying state-of-the-art learning algorithms to refine the representation. You are welcome to use this as a basis for further supply chain-related use cases or GNN-related implementations.
