Learning Mesh-Based Simulation with Graph Networks
By: Tobias Pfaff, Meire Fortunato, Alvaro Sanchez-Gonzalez, Peter W. Battaglia
Overview
A graph G can be defined as G = {N, E}, where N is the set of nodes representing the graph's elements and E is the set of edges representing the relationships between those elements [1]. In addition to this topological information, node features (e.g., the attributes of a user in a social network or physical properties in a molecular graph) and edge features (e.g., weights or distances), represented as vectors, can be embedded in the graph.
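As a minimal illustration (a toy example, not taken from the paper), such a graph with vector-valued node and edge features can be stored as plain arrays:

```python
import numpy as np

# Toy graph: 3 nodes, 2 directed edges (0 -> 1 and 1 -> 2).
senders   = np.array([0, 1])          # edge source indices
receivers = np.array([1, 2])          # edge target indices

node_features = np.random.rand(3, 4)  # one 4-dim feature vector per node
edge_features = np.random.rand(2, 2)  # one 2-dim feature vector per edge

graph = {
    "senders": senders,
    "receivers": receivers,
    "nodes": node_features,
    "edges": edge_features,
}
```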
Tasks performed by Graph Neural Networks can be categorized into node classification, link prediction, graph classification, community detection and anomaly detection [3]. MESHGRAPHNETS essentially performs node regression, which is conceptually close to node classification [4].
From solving Maxwell’s equations to model electromagnetic fields to solving the Navier-Stokes equations to model fluid motion and transport phenomena, many complex physical systems are simulated on mesh representations in order to solve the underlying partial differential equations. However, these methods are computationally very expensive and slow.
Although meshes have long been popular in areas like geometry processing, most previous deep-learning-based simulation methods focused on regular grids, mainly because architectures such as CNNs map well onto grids and onto existing hardware. Grid-based methods, however, introduce discretization error and require a lot of computation. As Fig. 3 shows, a regular grid struggles to provide high resolution around corners and curved boundaries, whereas an irregular mesh can. This flexibility is why mesh-based simulation methods such as FEM are popular: the resolution can be adjusted throughout the simulated space according to the geometry of the modelled object. Additionally, resolution-independent methods are crucial, and Graph Neural Networks (GNNs) emerge as a deep learning framework that meets these criteria for modelling physical phenomena.
Message Passing
Message passing, the central mechanism by which a GNN learns, refers to the communication of information between nodes in a graph. The key steps are:
1. Each node in the graph computes a message for its neighbours.
2. These messages are sent to the neighbours, and each node aggregates the messages it receives in a permutation-invariant fashion (e.g., by summing).
3. Based on the aggregated messages and its previous attributes, each node updates its attributes.
These steps are repeated over multiple iterations, allowing nodes to exchange information with increasingly distant neighbours and to refine their representations. By iteratively passing messages between nodes, a GNN can learn to capture complex relationships and dependencies within the graph structure.
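The sketch below (a simplified illustration, not the paper's implementation) performs one round of message passing with sum aggregation; `message_fn` and `update_fn` stand in for learned functions such as MLPs:

```python
import numpy as np

def message_passing_round(nodes, senders, receivers, message_fn, update_fn):
    """One round of message passing with permutation-invariant sum aggregation."""
    # 1. Each edge carries a message computed from its sender and receiver nodes.
    messages = message_fn(nodes[senders], nodes[receivers])

    # 2. Sum the incoming messages at each receiving node (order does not matter).
    aggregated = np.zeros_like(nodes)
    np.add.at(aggregated, receivers, messages)

    # 3. Each node updates its attributes from its old state and the aggregate.
    return update_fn(nodes, aggregated)

# Toy usage with placeholder message/update functions.
nodes = np.random.rand(3, 4)
senders, receivers = np.array([0, 1]), np.array([1, 2])
new_nodes = message_passing_round(
    nodes, senders, receivers,
    message_fn=lambda s, r: s + r,
    update_fn=lambda n, agg: n + agg,
)
```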
Dataset
Datasets from four experimental domains were used. Each dataset consists of 1000 training, 100 validation and 100 test trajectories, each containing 250–600 time steps. Meshing can be regular (all edges have similar length), irregular (edge lengths vary strongly in different regions of the mesh) or dynamic (the mesh changes at each step of the simulation trajectory). For Lagrangian systems, the world-edge radius rW is provided. The model operates on the simulation time step ∆t.
Each node in the dataset is encoded with information such as the node type, the mesh topology, and node attributes describing where the node is in the simulated space, its velocity, and information about its neighbouring nodes, along with the dynamical quantities that MESHGRAPHNETS is trained to predict (in CylinderFlow, the pressure and 2D velocity of the fluid).
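For CylinderFlow, a per-node record could look roughly like the following (an illustrative layout only; the field names are assumptions, not the dataset's actual schema):

```python
import numpy as np

# Hypothetical per-node record for one time step of a CylinderFlow trajectory.
node = {
    "node_type": 0,                       # e.g. 0 = fluid; other codes for walls, inlets, outlets
    "mesh_pos": np.array([0.31, 0.12]),   # position of the node in the 2D mesh
    "velocity": np.array([0.85, -0.02]),  # 2D fluid velocity (input and target quantity)
    "pressure": 1.013,                    # pressure at the node (target quantity)
}
```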
Task Setting
The state of the system at time t is encoded using a simulation mesh M^t = (V, E^M), with nodes V connected by mesh edges E^M, as explained in the Dataset section.
The task is to learn a forward model of the dynamic quantities of the mesh at time t+1 given the current mesh M^t and (optionally) a history of previous meshes {M^(t-1),…,M^(t-h)} .
A forward model takes one state of the system and produces the next state; once trained, the model can be applied to its own predictions over and over during the test phase to generate the whole simulation rollout.
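A rollout therefore amounts to repeatedly feeding the model's predictions back in as inputs, as in the sketch below (illustrative Python only; `model` and `integrate` are placeholders for the learned network and the state-update step):

```python
import numpy as np

def rollout(model, integrate, initial_state, num_steps):
    """Autoregressively unroll a learned forward model for num_steps steps."""
    states = [initial_state]
    for _ in range(num_steps):
        prediction = model(states[-1])                    # predicted dynamics for t+1
        states.append(integrate(states[-1], prediction))  # e.g. an Euler update
    return states

# Toy usage: the state is a per-node velocity field; the "model" predicts accelerations.
dt = 0.01
state0 = np.zeros((100, 2))
trajectory = rollout(
    model=lambda s: -0.1 * s + 1.0,         # stand-in for the trained network
    integrate=lambda s, acc: s + dt * acc,  # simple forward-Euler update
    initial_state=state0,
    num_steps=50,
)
```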
Loss and Metric
Figure 5 shows that MESHGRAPHNETS consists of an Encoder, a Processor and a Decoder. Each of these is built from MLPs with two hidden layers and an output layer, using ReLU activations, with outputs normalized by LayerNorm (except for the decoder output). To make the model robust over rollouts of hundreds of steps, random normal noise was added to the training inputs. The model was trained on a single V100 GPU with the Adam optimizer for 10M training steps, with an exponential learning-rate decay from 10⁻⁴ to 10⁻⁶ over 5M steps. The per-step predictions were trained with an L2 loss, and RMSE on the rollouts was used as the evaluation metric.
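The building-block MLP described above might look like the following (a hedged PyTorch sketch based on that description: two hidden layers, ReLU activations, and an optional LayerNorm on the output):

```python
import torch.nn as nn

def make_mlp(in_size, hidden_size, out_size, layer_norm=True):
    """Two-hidden-layer MLP with ReLU activations and optional output LayerNorm."""
    layers = [
        nn.Linear(in_size, hidden_size), nn.ReLU(),
        nn.Linear(hidden_size, hidden_size), nn.ReLU(),
        nn.Linear(hidden_size, out_size),
    ]
    if layer_norm:  # the decoder's output MLP skips this normalization
        layers.append(nn.LayerNorm(out_size))
    return nn.Sequential(*layers)
```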
Structure of MESHGRAPHNETS
Before understanding MESHGRAPHNETS, we have to understand GNS [6], since both papers share the same architecture; the new model extends this framework to work with meshes.
GNS and MESHGRAPHNETS both have Encoder-Processor-Decoder architecture.
Fig. 6 shows the GNS encoder: each water particle becomes a node, and neighbouring nodes are connected with edges. The encoder then embeds the features of every node and every edge independently; these embeddings are grouped into feature vectors stored on the graph nodes and graph edges. The result is a graph on which a GNN can operate.
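Connecting neighbouring particles is typically done with a radius query; the sketch below (an assumption about the construction, using SciPy's KD-tree) builds such an edge list from particle positions:

```python
import numpy as np
from scipy.spatial import cKDTree

def particles_to_edges(positions, radius):
    """Connect every pair of particles closer than `radius` with edges in both directions."""
    tree = cKDTree(positions)
    pairs = tree.query_pairs(radius, output_type="ndarray")  # shape (num_pairs, 2)
    senders = np.concatenate([pairs[:, 0], pairs[:, 1]])     # add both directions
    receivers = np.concatenate([pairs[:, 1], pairs[:, 0]])
    return senders, receivers

positions = np.random.rand(200, 2)  # toy 2D particle cloud
senders, receivers = particles_to_edges(positions, radius=0.1)
```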
The processor performs several rounds of message passing: in each round, every node in the graph computes a message for each of its neighbours and then pools all of the incoming messages along its edges to update its own information. As a result, the node and edge embeddings are updated at each step using local neighbourhood information; this is a fairly standard GraphNet block. The message-passing step is repeated across 5–20 such GraphNet blocks.
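One such GraphNet block might be written roughly as follows (a simplified PyTorch sketch of the idea, not the paper's code; it updates edge embeddings, pools them, then updates node embeddings, with residual connections):

```python
import torch
import torch.nn as nn

class GraphNetBlock(nn.Module):
    """One round of message passing: update edge embeddings, then node embeddings."""
    def __init__(self, latent_size):
        super().__init__()
        # Edge update: MLP over [edge, sender node, receiver node] embeddings.
        self.edge_mlp = nn.Sequential(
            nn.Linear(3 * latent_size, latent_size), nn.ReLU(),
            nn.Linear(latent_size, latent_size),
        )
        # Node update: MLP over [node, summed incoming edge] embeddings.
        self.node_mlp = nn.Sequential(
            nn.Linear(2 * latent_size, latent_size), nn.ReLU(),
            nn.Linear(latent_size, latent_size),
        )

    def forward(self, nodes, edges, senders, receivers):
        # Compute messages (new edge embeddings) from each edge and its endpoint nodes.
        new_edges = self.edge_mlp(
            torch.cat([edges, nodes[senders], nodes[receivers]], dim=-1))
        # Pool incoming messages at each receiver node (permutation-invariant sum).
        pooled = torch.zeros_like(nodes).index_add_(0, receivers, new_edges)
        # Update node embeddings; residuals keep deep stacks of blocks stable.
        new_nodes = self.node_mlp(torch.cat([nodes, pooled], dim=-1))
        return nodes + new_nodes, edges + new_edges

# Toy usage.
block = GraphNetBlock(latent_size=128)
nodes, edges = torch.randn(10, 128), torch.randn(20, 128)
senders, receivers = torch.randint(0, 10, (20,)), torch.randint(0, 10, (20,))
nodes, edges = block(nodes, edges, senders, receivers)
```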
In the decoder, the updated representation at each node is decoded into a particle acceleration, which is then fed to an Euler integrator to produce the next state; alternatively, one could directly predict the absolute position at the next time step.
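The integration step itself is tiny; a minimal sketch (assuming the decoder outputs per-node accelerations) is:

```python
import numpy as np

def euler_update(position, velocity, acceleration, dt):
    """Semi-implicit Euler step: update velocity first, then position."""
    next_velocity = velocity + dt * acceleration
    next_position = position + dt * next_velocity
    return next_position, next_velocity

pos = np.zeros((100, 2))
vel = np.zeros((100, 2))
acc = np.random.randn(100, 2)  # stand-in for decoded accelerations
pos, vel = euler_update(pos, vel, acc, dt=0.01)
```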
Fig. 5 shows that the MESHGRAPHNETS encoding process has some key differences. The most obvious one is that the input state is no longer a particle set but a mesh, which is already a graph, so no graph needs to be constructed here. However, the mesh edges alone may no longer be optimal for message passing.
In the real world, two nodes can be very close to each other even though in mesh space (e.g., on a 2D cloth mesh) they are far apart, which makes message passing along mesh edges inefficient. To handle this, additional world-space edges are added between nodes that are close in world space but distant in the mesh.
In the processor, the model does the same as before, except that message passing happens in both spaces: the model pools world-space and mesh-space messages separately. In the decoder, an acceleration is extracted for each mesh node and used to update the state with an Euler integrator.
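Constructing the extra world-space edges can be sketched as follows (an assumption about the procedure: connect nodes within the world-edge radius rW, excluding pairs that are already joined by a mesh edge):

```python
import numpy as np
from scipy.spatial import cKDTree

def world_edges(world_pos, mesh_edges, radius_w):
    """Pairs of nodes within radius_w in world space that are not already mesh neighbours."""
    existing = {tuple(sorted(e)) for e in mesh_edges}
    tree = cKDTree(world_pos)
    pairs = tree.query_pairs(radius_w, output_type="ndarray")
    keep = [p for p in pairs if tuple(sorted(p)) not in existing]
    return np.array(keep).reshape(-1, 2)

world_pos = np.random.rand(50, 3)             # toy 3D node positions
mesh_edges = [(i, i + 1) for i in range(49)]  # toy chain mesh
extra = world_edges(world_pos, mesh_edges, radius_w=0.1)
```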
Results
MESHGRAPHNETS was compared with two baseline models, GCN [7] and a CNN (UNet), and it outperformed both: the GCN lacks the ability to compute messages on edges, and the UNet inadequately sampled the crucial wake region around the wing tip in the AIRFOIL case.
Comparing the approach to GNS on the FLAGSIMPLE fixed-mesh dataset shows that mesh-space embedding and message passing are very useful. Because GNS has no notion of the cloth's resting state, it accumulates considerable error, making the simulation unstable, with only marginal improvement when a 5-step history is added.
Increasing the number of GraphNet blocks, which corresponds to the number of message-passing steps, typically improves performance, albeit at greater computational expense. A value of 15 was observed to strike a favourable balance between efficiency and accuracy across all examined systems.
It was also observed that the model performed best with the smallest possible history: increasing the number of history steps led to more and more overfitting.
Conclusion
MESHGRAPHNETS is a general-purpose mesh-based simulator that consistently generates high-quality rollouts across various domains. Outperforming both particle- and grid-based baselines, it demonstrates stability, accuracy and strong generalization to larger and more complex settings at test time. Notably, it is efficient, running 10–100 times faster than the ground-truth simulator. Its ability to efficiently model a diverse range of physical systems underscores its versatility and effectiveness for simulation.
References:
[1] : https://en.wikipedia.org/wiki/Graph_theory
[2] : https://www.cs.emory.edu/~cheung/Courses/253/Syllabus/Graph/intro.html
[3] : https://towardsdatascience.com/graph-neural-networks-with-pyg-on-node-classification-link-prediction-and-anomaly-detection-14aa38fe1275
[4] : https://medium.com/stanford-cs224w/learning-mesh-based-flow-simulations-on-graph-networks-44983679cf2d
[5] : Neural Message Passing for Quantum Chemistry, Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, George E. Dahl
[6] : A. Sanchez-Gonzalez, J. Godwin, T. Pfaff, R. Ying, J. Leskovec, and P. W. Battaglia. Learning to simulate complex physics with graph networks. http://proceedings.mlr.press/v119/sanchez-gonzalez20a.html.
[7] : Semi-Supervised Classification with Graph Convolutional Networks, Thomas N. Kipf, Max Welling