Locomotion Reconstruction with the Heterogeneous Graph Transformer
By Yiwen Dong, Samuel Hunter, Jin Liu as part of the Stanford CS224W course project
Summary
Gait health monitoring is important for diagnosing and treating patients with musculoskeletal or neuromuscular disorders [1] [2] [3] [4] [5]. Many existing studies have shown that gait health monitoring is critical for the diagnosis and rehabilitation of diseases such as dementia, muscular dystrophy, stroke, and joint injuries [6] [7] [8] [9] [10].
To monitor gait health, motion capture cameras are widely used and promising in providing accurate kinematic information when operated in well-calibrated gait clinics. However, motion capture systems often have missing data due to occlusion of the reflective markers and noise from the environment, and their accuracy relies on the precision of the marker locations on the body, which are often placed empirically and vary among subjects. In addition, post-processing is time-consuming, requiring additional biomechanical simulation software to obtain gait health status, including joint angles and forces. As a result, it typically takes a medical specialist more than one week to process the data and prepare gait health reports.
In this project, we propose direct locomotion reconstruction from spatial and temporal motion capture data using the Heterogeneous Graph Transformer (HGT) [11], a type of Graph Neural Network (GNN) specialized for handling heterogeneous graph structures. HGT is well suited to the missing/imprecise marker problem because it allows modeling of both observed and unobserved markers, leveraging the observed motion to predict the unobserved markers. Our approach aims to reduce the time required for clinical gait analysis and produce accurate predictions of joint positions. By modeling the spatio-temporal dependencies in lower-body locomotion during walking, our approach takes in measurements of body movements from existing subjects and predicts the locations of the missing markers to aid in estimating critical gait parameters for health monitoring.
Motion Capture Dataset
The dataset was collected by Yiwen Dong from the Stanford Structures as Sensors Lab using a Vicon motion capture system. The dataset contains 22 participants (ages 18 to 40), each walking back and forth across a 7-meter walkway 40 times. The experimental setup includes 10 infrared cameras aimed at the subject to capture the motion of the lower limbs, recording at a sampling rate of 100 Hz.
For each walking trial, 16 markers are attached to the subject’s lower limbs, producing (x, y, z) coordinates of locomotion. The gait events are manually labeled by researchers in the Structures as Sensors Lab. The labels include “foot strike” and “foot off”, indicating the beginning and end of the stance phase. A complete gait cycle is extracted from one “foot strike” to the next “foot strike” of the same leg. Figure 2(a) shows the walking experiment in action and Figure 2(b) shows the ankle marker coordinates corresponding to the gait cycles.
However, markers can be missing during the experiment for two main reasons. First, there can be visual occlusion caused by the movement of the arms and unexpected covering from clothes. Second, there’s visual noise in the environment such as reflections from the subject’s wearables, shoes, accessories, and surrounding furniture, resulting in ghost markers captured by the system.
Vicon Nexus exports a .csv for each participant’s trial with multiple tables in one file. The important ones are Trajectories and Events.
Data Pre-processing
Before generating graphs, we needed to preprocess the raw data since each .csv contains multiple gait cycles, and the coordinate system and gait cycle length aren’t standardized. We pre-processed the data with pandas in the following steps:
- Split the left and right marker trajectories
- Split trajectories into gait cycles by events
- Convert absolute positions to positions relative to the pelvis markers’ centroid
- Resample the gait cycle time into standard percentage values (0–100%)
First, a typical gait cycle involves only one leg. The left and right cycles overlap each other, so we split the trajectories into individual left and right gait cycles, each with its own markers (8 per side).
Next, the positions of the markers are recorded as absolute locations in the global coordinates set during calibration, which vary for each subject, making it difficult to align different people. To address this problem, we fixed the centroid of the 4 pelvis markers as the origin and computed the locations of the other markers relative to it. The centroid approximates the center of mass, and its trajectory can be easily computed by averaging the absolute movement of the pelvis markers.
Also, the duration of each gait cycle varies, making it difficult to align samples. Therefore, we resampled each cycle to a standard 100 timesteps, with each timestep representing 1% of a gait cycle, which aligns with the clinical standard for gait analysis.
Finally, we transformed the marker columns into rows, mapping marker names to integer indices 0–7 in alphabetical order. Any gait cycle with missing data was omitted from the training data, since we need the ground truth positions for comparison.
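To make the pelvis-centering and resampling steps concrete, here is a minimal sketch, assuming the marker trajectories for one gait cycle have already been split into NumPy arrays (the actual code in Clean.ipynb uses pandas and may differ in detail):

import numpy as np

def center_and_resample(markers, pelvis, n_steps=100):
    # Hypothetical helper: `markers` is a (T, 8, 3) array of one leg's marker
    # trajectories over a gait cycle, `pelvis` is a (T, 4, 3) array of pelvis markers
    origin = pelvis.mean(axis=1, keepdims=True)      # pelvis centroid per frame, (T, 1, 3)
    rel = markers - origin                           # positions relative to the centroid
    # Resample the cycle onto a standard 0-100% grid
    old_t = np.linspace(0.0, 1.0, num=rel.shape[0])
    new_t = np.linspace(0.0, 1.0, num=n_steps)
    resampled = np.empty((n_steps, rel.shape[1], 3))
    for m in range(rel.shape[1]):
        for k in range(3):
            resampled[:, m, k] = np.interp(new_t, old_t, rel[:, m, k])
    return resampled                                 # (100, 8, 3)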
After repeating this procedure for all of the participants and trials, there were 5153 cleaned individual gait cycles.
The exact details and code can be found in Clean.ipynb:
https://drive.google.com/file/d/1Wf-DGamiHw21NA2Dq0Y8FG26-YL48mqW/view?usp=share_link
Modeling Lower Limb Locomotion as a Graph
Each node in our graph represents a marker at a given timestamp (there are 8 markers and 100 timestamps in our data), so there are 800 nodes in total. The node attributes are the marker index and time, and the node features are the (x, y, z) position. To simulate unobserved node features (i.e., missing data in the motion capture system), we replace the positions of 20% of the nodes with NaN (a sketch of this masking appears after the edge list below). The node targets are the ground truth positions.
There are two types of edges in the graph: spatial and temporal.
- Spatial edges (blue lines) capture the physical connectivity of bones between the joints at the markers. There are 11 edges between the 8 markers, which are hardcoded for a single timestep and replicated across the corresponding nodes at all timesteps.
- Temporal edges (black lines) capture the temporal dependencies during the subject’s locomotion. Each marker is connected to itself at the previous and next timestep, except for the final timestep.
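As a rough illustration (not the exact notebook code), the node tensors and edge indices for one gait cycle could be assembled along these lines; SPATIAL_PAIRS is a placeholder for the 11 hardcoded bone connections:

import torch

N_MARKERS, N_STEPS = 8, 100
# Hypothetical bone connections between marker indices within one timestep;
# the real graph uses 11 hardcoded pairs
SPATIAL_PAIRS = [(0, 1), (1, 2)]

def build_tensors(positions, unobserved_frac=0.2):
    # `positions` is assumed to be a (100, 8, 3) array of pelvis-relative marker positions;
    # node index n = t * N_MARKERS + m refers to marker m at timestep t
    node_targets = torch.as_tensor(positions, dtype=torch.float).reshape(-1, 3)  # (800, 3)
    marker_id = torch.arange(N_MARKERS).repeat(N_STEPS)
    time = torch.arange(N_STEPS).repeat_interleave(N_MARKERS)
    node_attributes = torch.stack([marker_id, time], dim=1)                      # (800, 2)
    # Simulate missing data: mask 20% of the nodes with NaN positions
    n_nodes = N_MARKERS * N_STEPS
    idx_unobserved = torch.randperm(n_nodes)[: int(unobserved_frac * n_nodes)]
    node_features = node_targets.clone()
    node_features[idx_unobserved] = float('nan')
    # Spatial edges: the bone connections replicated at every timestep
    spatial = [(t * N_MARKERS + i, t * N_MARKERS + j)
               for t in range(N_STEPS) for i, j in SPATIAL_PAIRS]
    # Temporal edges: each marker connected to itself at the next timestep
    temporal = [(t * N_MARKERS + m, (t + 1) * N_MARKERS + m)
                for t in range(N_STEPS - 1) for m in range(N_MARKERS)]
    spatial_edge_index = torch.tensor(spatial, dtype=torch.long).t().contiguous()
    time_edge_index = torch.tensor(temporal, dtype=torch.long).t().contiguous()
    return (node_features, node_attributes, node_targets,
            time_edge_index, spatial_edge_index, idx_unobserved)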
Propagation of observed features to initialize the unobserved markers
To initialize the unobserved nodes (nodes with NaN), we apply a preprocessing step known as feature propagation [12]. Feature propagation handles missing features with a diffusion-like operation on the graph. It is deterministic and extremely fast. We use the propagated features as an initial guess for the target positions to feed into the GNN.
PyG provides an implementation, torch_geometric.transforms.FeaturePropagation, but it is limited to homogeneous graphs. We therefore create two homogeneous graphs from the observed node features, one with only the spatial edges and the other with only the temporal edges. We run feature propagation on both and average the resulting node features for our full heterogeneous graph.
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import FeaturePropagation

def prop_observed_features(node_observed_features, node_attributes, time_edge_index, spatial_edge_index):
    transform = FeaturePropagation(missing_mask=torch.isnan(node_observed_features))
    # Propagate along the temporal edges only
    time_data_prop = transform(
        Data(
            x=node_observed_features,
            node_attr=node_attributes,
            edge_index=time_edge_index,
        )
    )
    # Propagate along the spatial edges only
    spatial_data_prop = transform(
        Data(
            x=node_observed_features,
            node_attr=node_attributes,
            edge_index=spatial_edge_index,
        )
    )
    # Average the two propagated estimates to initialize the unobserved nodes
    node_features = (time_data_prop.x + spatial_data_prop.x) / 2.0
    return node_features
We combine the node attributes, features, targets, and the spatial and temporal edges into a heterogeneous graph, represented in PyG as torch_geometric.data.HeteroData.
from torch_geometric.data import HeteroData

def build_graph(node_features, node_attributes, node_targets, time_edge_index, spatial_edge_index, idx_unobserved, filename):
    data = HeteroData()
    # Each marker has 2 attributes (marker_id, time)
    data['marker'].node_attr = node_attributes
    # Each marker has 3 features (x, y, z)
    data['marker'].x = node_features
    # The target (ground truth position) has 3 features (x, y, z)
    data['marker'].y = node_targets
    # Marker time edges connect the previous timestep to the current timestep
    data['marker', 'time', 'marker'].edge_index = time_edge_index
    # Marker spatial edges (replicated at each timestep)
    data['marker', 'space', 'marker'].edge_index = spatial_edge_index
    # Indices of the nodes whose features were masked with NaN
    data['marker'].idx_unobserved = idx_unobserved
    # Graph name (remove the .csv extension)
    data.label = filename[:-4]
    return data
We repeat this procedure for all of the 5153 cleaned gait cycles, so we have a List[HeteroData].
Batching
To sample from the list of graphs, we use torch_geometric.loader.DataLoader, which creates mini-batches of HeteroData. Each batch is of type HeteroDataBatch, where all the disconnected graphs in the batch are merged into one, and the batch variable stores the index of which graph each node belongs to. We split the data into 80% training and 20% test. The training data is loaded in batches of 32 and shuffled each epoch. The test data is loaded in batches of 1 and not shuffled.
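A minimal sketch of this split and loading step (assuming data_list is the List[HeteroData] built above) might look like:

import torch
from torch_geometric.loader import DataLoader

# Shuffle once, then split 80% / 20%
perm = torch.randperm(len(data_list)).tolist()
n_train = int(0.8 * len(data_list))
train_loader = DataLoader([data_list[i] for i in perm[:n_train]], batch_size=32, shuffle=True)
test_loader = DataLoader([data_list[i] for i in perm[n_train:]], batch_size=1, shuffle=False)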
Background of Heterogeneous Graph Transformer (HGT)
Graph Neural Networks (GNNs) are specialized neural networks designed to handle graph-structured data, including homogeneous and heterogeneous graphs, and can process inputs with a flexible graph-like structure as opposed to a fixed grid-like structure. The Heterogeneous Graph Transformer (HGT) [11] is well suited to our application, which has both spatial and temporal dependencies. We prefer HGT over the spatio-temporal graph convolutional network (STGCN) [14] [15] because HGT can capture all heterogeneous information in a single convolution layer, as opposed to splitting the spatial and temporal edges/graphs into separate graph convolution layers as in STGCN.
HGT has three components:
- Meta relation-aware heterogeneous mutual attention
- Heterogeneous message passing from source nodes
- Target-specific heterogeneous message aggregation
Meta relation-aware heterogeneous mutual attention
The HGT handles different types of relations through their meta relations, i.e., <τ(s), φ(e), τ(t)> triplets, which in our case are <τ(marker), φ(spatial), τ(marker)> and <τ(marker), φ(temporal), τ(marker)>. The attention for each edge e = (s, t) is computed with h attention heads as follows.
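Paraphrasing the formulation from [11] in the notation of this post (d denotes the per-head vector dimension), the i-th attention head and the combined attention are:

ATT-head^i(s, e, t) = ( K^i(s) W_(φ(e))^(ATT) Q^i(t)^T ) · μ_<τ(s), φ(e), τ(t)> / sqrt(d)
K^i(s) = K-Linear_(τ(s))^i( H^(l-1)[s] ),   Q^i(t) = Q-Linear_(τ(t))^i( H^(l-1)[t] )
Attention_HGT(s, e, t) = Softmax over s ∈ N(t) of ( Concat over i = 1..h of ATT-head^i(s, e, t) )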
Here, N(t) represents the neighborhood of node t, and H^(l-1)[t] represents the (l-1)-th layer representation of node t. HGT projects the τ(s)-type source node (marker) s into the i-th key vector K^i(s) with a linear projection K-Linear, and projects the τ(t)-type target node (marker) t into the i-th query vector Q^i(t) with Q-Linear. After the projection, HGT keeps a distinct edge-based matrix W_(φ(e))^(ATT) for each edge type φ(e), so the model can capture different semantic relations. In addition, HGT applies a scaling tensor μ to denote the general significance of each meta relation triplet. Finally, HGT concatenates the h attention heads to get the attention vector. Then, for each node t, HGT gathers the attention vectors from its neighborhood N(t) and applies a softmax to obtain Attention_HGT(s, e, t).
Heterogeneous message passing from source nodes
Parallel to the mutual attention, HGT also calculates multi-head messages by the following equation.
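Paraphrasing [11] again, the message heads and the combined message are:

MSG-head^i(s, e, t) = M-Linear_(τ(s))^i( H^(l-1)[s] ) W_(φ(e))^(MSG)
Message_HGT(s, e, t) = Concat over i = 1..h of MSG-head^i(s, e, t)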
HGT projects the τ(s)-type source node s into the i-th message vector with M-Linear, then multiplies it by the W_(φ(e))^(MSG) matrix to incorporate the edge dependency. The final step is to concatenate all h message heads to get the message Message_HGT(s, e, t).
Target-specific heterogeneous message aggregation
After computing both the multi-head attention and the messages, HGT aggregates the neighborhood information for node t from all of its neighbors (source nodes), which may have different feature distributions.
The final step is to map target node t’s vector back to its type-specific distribution with an A-Linear projection, followed by a residual connection.
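Paraphrasing [11], the aggregation and the layer update are:

H~^(l)[t] = Aggregate over s ∈ N(t) of ( Attention_HGT(s, e, t) · Message_HGT(s, e, t) )
H^(l)[t] = A-Linear_(τ(t))( σ( H~^(l)[t] ) ) + H^(l-1)[t]

With group='sum' in our HGTConv layers, the aggregation over the neighbors is a sum.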
In our project, the final-layer representation H^(L) is used to predict the missing marker features, i.e., the (x, y, z) positions.
HGT model
We make use of the torch_geometric.nn.conv.HGTConv layer, which implements the HGT operator from [11]. We used 2 attention heads, since there are 2 types of edges to attend to. The metadata parameter contains the node and edge types for the heterogeneous graph, which are identical for every training and test example. For our data, marker is the only node type, so there is 1 linear layer followed by a ReLU activation at the beginning of the stack. After num_layers HGT convolutions, we apply a final linear layer to output the predictions. In our application, we used 128 hidden channels, 3 out channels (x, y, z), 2 heads (spatial and temporal edges), 2 layers, and 1 node type (only marker).
import torch
from torch_geometric.nn import HGTConv, Linear

class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers, node_types, metadata):
        super().__init__()
        # One input projection per node type (only 'marker' in our case)
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)
        # Stack of HGT convolution layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, metadata,
                           num_heads, group='sum')
            self.convs.append(conv)
        # Output projection to the (x, y, z) predictions
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for node_type, x in x_dict.items():
            x_dict[node_type] = self.lin_dict[node_type](x).relu_()
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['marker'])
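With the settings above, the model can be instantiated like this (assuming data_list[0] is one of our HeteroData graphs, whose metadata() returns the node and edge types):

model = HGT(hidden_channels=128, out_channels=3, num_heads=2, num_layers=2,
            node_types=['marker'], metadata=data_list[0].metadata())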
Training
Since we have a regression problem to predict the marker coordinates, we use MSE loss, which is the average squared error between the predicted and ground truth positions.
Since our data is loaded in mini-batches, we accumulate the loss weighted by the number of graphs in each batch.
import torch.nn.functional as F

def train(model, optimizer, train_loader):
    model.train()
    total_examples = total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device, 'edge_index')
        out = model(batch.x_dict, batch.edge_index_dict)
        target = batch['marker'].y
        loss = F.mse_loss(out, target)
        loss.backward()
        optimizer.step()
        # Weight each batch's loss by the number of graphs it contains
        total_examples += batch.num_graphs
        total_loss += float(loss) * batch.num_graphs
    return total_loss / total_examples
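A typical loop around this function might look like the sketch below; the learning rate shown is illustrative rather than the tuned value from our hyperparameter search:

device = torch.device('cpu')  # training on CPU for this sketch
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, 101):
    loss = train(model, optimizer, train_loader)
    if epoch % 10 == 0:
        print(f'Epoch {epoch:03d}, train MSE loss: {loss:.4f}')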
Evaluation metric
Our evaluation metric is Mean Per Joint Position Error (MPJPE), the average Euclidean distance between the predicted and ground truth positions [13]. MPJPE takes the square root of each squared position error (i.e., the per-node Euclidean distance), then averages over all markers and timesteps.
This has better physical intuition as the average error distance (in mm) compared to Root Mean Squared Error (RMSE). RMSE averages all the squared position errors over all markers and timesteps first, then takes the square root.
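Concretely, if p̂_i and p_i are the predicted and ground truth positions of marker-timestep node i, and N = 800 is the number of nodes per graph:

MPJPE = (1/N) · Σ_i || p̂_i − p_i ||_2
RMSE = sqrt( (1/N) · Σ_i || p̂_i − p_i ||_2^2 )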
Here are similar papers that use MPJPE:
- https://arxiv.org/pdf/2104.03020.pdf
- https://arxiv.org/pdf/2205.12583.pdf
- https://arxiv.org/pdf/1904.03289.pdf
- https://arxiv.org/pdf/2105.10882.pdf
@torch.no_grad()
def test(model, test_loader):
    model.eval()
    total_examples = total_mpjpe = 0
    for batch in test_loader:
        batch = batch.to(device, 'edge_index')
        out = model(batch.x_dict, batch.edge_index_dict)
        target = batch['marker'].y
        # Mean Per Joint Position Error (MPJPE): Euclidean distance per node (row), then averaged
        mpjpe = torch.linalg.norm(out - target, ord=2, dim=1).mean()
        total_examples += batch.num_graphs
        total_mpjpe += float(mpjpe) * batch.num_graphs
    return total_mpjpe / total_examples
Hyperparameter search
Hyperparameters are the parameters outside of the trainable ones in our model that influence the model architecture and training. We use a package called Ray Tune to search for the optimal learning rate, batch_size, hidden_dim, and num_layers. You specify the hyperparameters and the values to sample from, and Ray Tune then trains trials of your model, pruning some and continuing to train the most promising ones according to a metric (in our case, the loss).
from ray import tune
from ray.tune.schedulers import ASHAScheduler

tuner = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(train_full, data_list=data_list),
        resources={
            "cpu": 4,  # 16 cores / 4 cpus per trial = 4 concurrent trials
            "gpu": 0,
        }
    ),
    tune_config=tune.TuneConfig(
        num_samples=4,
        scheduler=ASHAScheduler(
            metric='loss',
            mode='min',
            grace_period=1,
        ),
    ),
    param_space={
        'lr': tune.loguniform(1e-6, 1e-2),
        'batch_size': tune.choice([16, 32]),
        'hidden_dim': tune.grid_search([32, 64, 128]),
        'num_layers': tune.grid_search([1, 2]),
    },
)
results = tuner.fit()

best_result = results.get_best_result('loss', mode='min')
best_result.config
Ray Tune initializes each trial with a “trainable” — a function that creates new DataLoaders and an HGT model with a different hyperparameter config, trains, then reports the loss back to the Ray session.
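The exact train_full lives in our Colab; a minimal sketch of such a trainable (assuming the Ray AIR session API for reporting, and reusing the train function above) might look like:

import torch
from ray.air import session
from torch_geometric.loader import DataLoader

def train_full(config, data_list=None):
    # Build fresh DataLoaders and a fresh model for this trial's hyperparameter config
    n_train = int(0.8 * len(data_list))
    train_loader = DataLoader(data_list[:n_train], batch_size=config['batch_size'], shuffle=True)
    model = HGT(hidden_channels=config['hidden_dim'], out_channels=3, num_heads=2,
                num_layers=config['num_layers'], node_types=['marker'],
                metadata=data_list[0].metadata())
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    for epoch in range(100):
        loss = train(model, optimizer, train_loader)
        # Report the loss so the ASHA scheduler can prune unpromising trials
        session.report({'loss': loss})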
The optimal hyperparameter config we arrived at has a medium sized batch size of 32 and the largest hidden_dim and num_layers possible, indicating that a more expressive model performs better.
Further training
With the most promising hyperparameter config, we trained the model for 14,000 epochs, storing checkpoints with the model and optimizer’s state_dict, mpjpes, and losses, so training can be resumed later.
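A simple checkpointing scheme along these lines (the file name and exact fields are illustrative) lets training resume from a saved state:

# Save a checkpoint; `losses` and `mpjpes` are the metric histories accumulated so far
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'losses': losses,
    'mpjpes': mpjpes,
}, 'hgt_checkpoint.pt')

# Resume later
checkpoint = torch.load('hgt_checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])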
Results
Here is the training loss and test MPJPE over 14,000 epochs. The first significant drop in training loss occurs around 10 epochs, and the loss continues to decrease slowly until about 2500 epochs. The test MPJPE (averaged over all test examples) reaches a minimum of 4.1987 mm.
We can compare the predicted and ground truth marker positions as animated 3D points. It’s hard to see the difference since they overlap near perfectly for most timesteps. The MPJPE of the current frame is displayed in the upper right of the plot.
Here are the observed marker positions, the feature-propagation output fed into the HGT, the HGT’s predictions, and the ground truth positions. Note how feature propagation alone was unable to produce temporal cohesiveness.
We can also visualize the predictions as training progressed. Several hundred epochs were sufficient for accurate predictions, but they were still a bit jittery, so we trained for longer.
Thanks for taking the time to read our post! Feel free to leave a comment if you have any questions.
Code
View our Colab to explore all the code discussed in this post:
https://colab.research.google.com/drive/18jfPEoYKTO0L7KQpTLoo7UAwkQzheJUz?usp=sharing
References
[1] J. M. Brazill, A. T. Beeve, C. S. Craft, J. J. Ivanusic, and E. L. Scheller, “Nerves in Bone: Evolving Concepts in Pain and Anabolism,” Journal of Bone and Mineral Research, vol. 34, no. 8, pp. 1393–1406, 2019.
[2] C. A. McGibbon, “Toward a better understanding of gait changes with age and disablement: neuromuscular adaptation,” Exercise and sport sciences reviews, vol. 31, no. 2, pp. 102–108, 2003.
[3] G. Martino, Y. P. Ivanenko, A. D’Avella, M. Serrao, A. Ranavolo, F. Draicchio, G. Cappellini, C. Casali, and F. Lacquaniti, “Neuromuscular adjustments of gait associated with unstable conditions,” Journal of neurophysiology, vol. 114, no. 5, pp. 2867–2882, 2015.
[4] J. Barth, J. Klucken, P. Kugler, T. Kammerer, R. Steidl, J. Winkler, J. Hornegger, and B. Eskofier, “Biometric and mobile gait analysis for early diagnosis and therapy monitoring in Parkinson’s disease,” Proceedings of the Annual International Conference of the IEEE Engineering in Medicine and Biology Society, EMBS, pp. 868–871, 2011.
[5] N. Giladi, “Medical treatment of freezing of gait,” Movement Disorders: Official Journal of the Movement Disorder Society, vol. 23, no. S2, pp. S482–S488, 2008.
[6] J. R. Gage, M. H. Schwartz, S. E. Koop, and T. F. Novacheck, The identification and treatment of gait problems in cerebral palsy, vol. 180. John Wiley & Sons, 2009.
[7] M. G. D’Angelo, M. Berti, L. Piccinini, M. Romei, M. Guglieri, S. Bonato, A. Degrate, A. C. Turconi, and N. Bresolin, “Gait pattern in Duchenne muscular dystrophy,” Gait & posture, vol. 29, no. 1, pp. 36–41, 2009.
[8] O. Beauchet, C. Annweiler, M. L. Callisaya, A.-M. De Cock, J. L. Helbostad, R. W. Kressig, V. Srikanth, J.-P. Steinmetz, H. M. Blumen, J. Verghese, and Others, “Poor gait performance and prediction of dementia: results from a meta-analysis,” Journal of the American Medical Directors Association, vol. 17, no. 6, pp. 482–490, 2016.
[9] S. Nadeau, M. Betschart, and F. Bethoux, “Gait analysis for poststroke rehabilitation: the relevance of biomechanical analysis and the impact of gait speed,” Physical Medicine and Rehabilitation Clinics, vol. 24, no. 2, pp. 265–276, 2013.
[10] J. A. DeLisa, Gait analysis in the science of rehabilitation, vol. 2. Diane Publishing, 1998.
[11] Z. Hu, Y. Dong, K. Wang, and Y. Sun, “Heterogeneous Graph Transformer,” CoRR, vol. abs/2003.01332, 2020.
[12] E. Rossi, H. Kenlay, M. I. Gorinova, B. P. Chamberlain, X. Dong, and M. M. Bronstein, “On the Unreasonable Effectiveness of Feature Propagation in Learning on Graphs with Missing Node Features,” CoRR, vol. abs/2111.12128, 2021.
[13] C. Ionescu, D. Papava, V. Olaru and C. Sminchisescu, “Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human Sensing in Natural Environments,” in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 36, no. 7, pp. 1325–1339, July 2014, doi: 10.1109/TPAMI.2013.248.
[14] S. Yan, Y. Xiong, and D. Lin, “Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition,” AAAI, 2018.
[15] B. Yu, H. Yin, and Z. Zhu, “Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting,” arXiv preprint, 2017.