Automated Theorem Proving with Graph Neural Networks

Daniel Jenson
Stanford CS224W GraphML Tutorials
May 15, 2023

By Dan Jenson, Julian Cooper, and Daniel Huang as part of the CS 224W course project at Stanford University.

Interactive & Automated Theorem Proving (ITP & ATP)

Theorem proving is a difficult, lengthy process that has historically been the exclusive domain of mathematicians. With advances in hardware and software, the late 1980s and early 1990s saw the advent of interactive theorem proving (ITP) assistants. These are software programs that help mathematicians prove statements by providing a formal syntax for composing proofs and a verification process for validating them.

Theorem Proving in Coq, from CoqGym [7].

More recently, some of these programs have begun to provide automated theorem proving (ATP) components, which can automatically select tactics and prove simpler subgoals. Since 2019, these ATP modules have been further augmented with neural networks, creating a new class of ATPs called Neural Theorem Provers (NTPs). NTPs are often structured as neural network encoder-decoders together with a reinforcement learning (RL) component. The encoder embeds the proof goal and its relevant context, and the decoder uses these embeddings to identify valid proof tactics. At runtime, an RL agent takes the decoder's output, i.e. the highest-ranked tactics, and dispatches them to the interactive theorem prover, which responds with either failure or success together with a list of new subgoals. This process repeats until the proof is complete or a predefined computation budget is exhausted.

Neural Theorem Proving (NTP)

Since 2019, research on neural theorem proving has used one of two platforms. The first, CoqGym, uses the Coq ITP and originated at Princeton University. The second, HOList, uses the HOL Light ITP and was developed at Google. Coq is based on the calculus of inductive constructions (CIC) and uses dependent types, whereas HOL Light is based on classical higher-order logic. Many papers cover the design foundations of these two ITPs, so we will not spend more time on the underlying theory here.

For this project, we chose to extend CoqGym, which has an open-source dataset and code repository. The HOList dataset is currently not publicly available (its download link is dead), and only select portions of its code have been open-sourced. Accordingly, the rest of this post describes CoqGym and our extensions to it.

The CoqGym Dataset

CoqGym is the platform developed by Kaiyu Yang and Jia Deng at Princeton University. It consists of 43,844 training, 13,875 validation, and 13,137 test proofs. Training the NTP encoder-decoder neural network uses proof steps, which are components of complete proofs, and there are 189,824 proof steps across the training and validation datasets.

Extracting proofs and proof steps from Coq projects is an expensive endeavor, both in terms of time and computation. Fortunately, CoqGym has some utilities for extracting proofs from source code written in Coq. The original dataset took over 4 hours to extract. However, we had to modify this process to also extract node features and edge indices for the abstract syntax trees (ASTs) that represent terms in Coq. This dramatically increased the computation time required, so we multi-processed extraction and implemented a number of modifications to their pipeline (detailed in the CoqGym-GNN Readme) to keep data extraction under a day.

Below is an example of a DataBatch object representing a proof step that is extracted and serialized during this process.

DataBatch(
  x=[5320, 1],
  edge_index=[2, 10546],
  batch=[5320],
  ptr=[48],
  file='../data/mod-red/multired.json',
  proof_name='sum_r'_i',
  n_step=16,
  env=[10],
  local_context=[36],
  goal={
    id=1550,
    text='Z.lt Z0 M',
    ast=Tree(constructor_app, [Tree(constructor_const, [Tree(constructor_constant, [Tree(constructor_mpdot, [Tree(constructor_mpfile, [Tree(constructor_dirpath, [])]), Tree(names__label__t, [])]), Tree(constructor_dirpath, []), Tree(names__label__t, [])]), Tree(constructor_instance, [])]), Tree(constructor_construct, [Tree(names__constructor, [Tree(names__inductive, [Tree(constructor_mutind, [Tree(constructor_mpfile, [Tree(constructor_dirpath, [])]), Tree(constructor_dirpath, []), Tree(names__label__t, [])]), Tree(int, [])]), Tree(int, [])]), Tree(constructor_instance, [])]), Tree(constructor_var, [])])
  },
  tactic={
    text='omega',
    actions=[1]
  },
  is_synthetic=False,
  tactic_actions=[1],
  tactic_str='omega'
)
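The `ast` field above is a nested `Tree`, and extracting "node features and edge indices" amounts to flattening each such tree into integer node-type ids (the single column of `x`) and a two-row `edge_index`. The sketch below is a minimal illustration, not CoqGym's extraction code: the `Tree` dataclass, `tree_to_graph`, and the `vocab` mapping are hypothetical. It stores each parent-child edge in both directions because the counts in the DataBatch appear consistent with that: `ptr=[48]` implies 47 graphs (10 environment terms, 36 local-context terms and 1 goal), and 2 × (5,320 − 47) = 10,546.

```python
from dataclasses import dataclass, field


@dataclass
class Tree:
    """Hypothetical stand-in for the Tree objects in the `ast` field above."""
    label: str
    children: list = field(default_factory=list)


def tree_to_graph(root, vocab, undirected=True):
    """Flatten a Tree into node-type ids and a [sources, targets] edge list.

    Nodes are numbered in pre-order (parent before children, left to right).
    vocab maps each node label to an integer id (hypothetical mapping).
    """
    node_types, sources, targets = [], [], []
    stack = [(root, None)]  # (node, index of its parent)
    while stack:
        node, parent = stack.pop()
        idx = len(node_types)
        node_types.append(vocab[node.label])
        if parent is not None:
            sources.append(parent)  # parent -> child edge
            targets.append(idx)
        # push children in reverse so they are popped left to right
        for child in reversed(node.children):
            stack.append((child, idx))
    if undirected:
        # also store child -> parent, so a tree with n nodes yields 2 * (n - 1) edges
        sources, targets = sources + targets, targets + sources
    return node_types, [sources, targets]


# Usage on a tiny three-node term
vocab = {"constructor_app": 0, "constructor_const": 1, "constructor_var": 2}
term = Tree("constructor_app", [Tree("constructor_const"), Tree("constructor_var")])
node_types, edge_index = tree_to_graph(term, vocab)
print(node_types)  # [0, 1, 2]
print(edge_index)  # [[0, 0, 1, 2], [1, 2, 0, 0]]
```

Batching many such graphs then only requires offsetting each graph's node indices and recording which graph every node belongs to, which is what the `batch` and `ptr` fields hold.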

Training

Excluding data extraction, the end-to-end pipeline consists of three phases: (1) proof step encoding and batching, (2) tactic decoding with beam search, and (3) proof tree search with depth-first search. However, for training the encoder-decoder, only the first two steps of this process are relevant, since we can use teacher forcing with solved human proofs to teach the model which tactics are “correct” for a given proof step.

The phases of the Neural Theorem Prover.
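Teacher forcing means that during training the decoder is always conditioned on the previous token of the human-written tactic rather than on its own prediction, so a single pass scores the whole target and no search is needed. The toy sketch below illustrates only the loss computation: `NEXT_TOKEN_LOGPROBS` is a hypothetical bigram table standing in for the GRU decoder, and tactics are flattened into token sequences for simplicity.

```python
import math

# Hypothetical toy decoder: log-probability of the next tactic token given the
# previous token. The real decoder conditions on the GRU state instead.
NEXT_TOKEN_LOGPROBS = {
    "<s>": {"apply": math.log(0.6), "omega": math.log(0.4)},
    "apply": {"H": math.log(0.9), "</s>": math.log(0.1)},
    "H": {"</s>": math.log(1.0)},
    "omega": {"</s>": math.log(1.0)},
}


def teacher_forced_nll(gold_tokens):
    """Negative log-likelihood of a human-written tactic under teacher forcing.

    Each step is conditioned on the gold previous token, never on the model's
    own guess, so one pass scores the whole target sequence.
    """
    prev, nll = "<s>", 0.0
    for tok in gold_tokens + ["</s>"]:
        nll -= NEXT_TOKEN_LOGPROBS[prev][tok]
        prev = tok  # feed the ground truth forward
    return nll


print(f"{teacher_forced_nll(['apply', 'H']):.3f}")  # 0.616
```

Training minimizes this quantity averaged over proof steps; the search phase (3) is only needed at test time, when no gold tactic is available.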

The first step of training creates embeddings for the environment, local context, and goal terms of each proof step. The environment and local context together can contain hundreds of terms, each represented as its own AST, so a single “batch” of proof steps might contain many thousands of ASTs. The original CoqGym paper introduced the ASTactic model, which uses a TreeLSTM to create an embedding for each of these terms (ASTs). We instead represent terms as PyTorch Geometric graphs and encode them with GraphSage [4] or Graph Attention Network (GAT) [6] convolutional layers. For each of these models, we also experimented with mean and max pooling. Finally, we created shallow embeddings for AST node types, which dramatically improved performance. The embedding process is visualized below.

The encoder framework.
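After message passing, each term's node embeddings must be collapsed into one vector per AST. The NumPy sketch below mimics what PyTorch Geometric's `global_mean_pool` and `global_max_pool` do with a batch's `batch` vector (the graph index of every node); the function name `global_pool` is ours, and it assumes every graph owns at least one node.

```python
import numpy as np


def global_pool(x, batch, num_graphs, mode="mean"):
    """Collapse node embeddings into one vector per graph (here, per AST).

    x: (num_nodes, dim) node embeddings after message passing.
    batch: (num_nodes,) graph index of each node, as in a PyG Batch.
    Assumes every graph index in range(num_graphs) owns at least one node.
    """
    out = np.zeros((num_graphs, x.shape[1]))
    for g in range(num_graphs):
        members = x[batch == g]  # rows belonging to graph g
        out[g] = members.mean(axis=0) if mode == "mean" else members.max(axis=0)
    return out


# Three nodes: the first two belong to term 0, the last to term 1
x = np.array([[1.0, 2.0], [3.0, 0.0], [5.0, 5.0]])
batch = np.array([0, 0, 1])
print(global_pool(x, batch, 2, "mean"))  # rows [2, 1] and [5, 5]
print(global_pool(x, batch, 2, "max"))   # rows [3, 2] and [5, 5]
```

Mean pooling averages over all nodes of a term, while max pooling keeps only the strongest activation in each feature dimension, which lets a single distinctive subterm dominate the term's embedding.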

The decoder is subtler in its construction. At its heart is a Gated Recurrent Unit (GRU), which takes input from an attention module, from the previous hidden state, and from an embedding of the current goal. The original attention module was a two-layer feed-forward neural network with a ReLU activation. Following guidance from Design Space for Graph Neural Networks by You et al. [8] and Layer Normalization by Ba et al. [1], we made the attention module considerably richer by adding a third linear layer, replacing ReLU with PReLU, and inserting layer normalization between layers.
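For readers less familiar with GRUs, the NumPy sketch below shows a single GRU update using the standard gate equations (the form PyTorch's `nn.GRUCell` uses), with biases omitted for brevity. How the goal embedding and the attended context are combined into the step input is an assumption here: we simply concatenate them. All sizes and weights are hypothetical.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def gru_step(h, x, W, U):
    """One GRU update in the nn.GRUCell form, with biases omitted.

    h: previous hidden state, shape (H,)
    x: step input, shape (D,)
    W, U: dicts of input (H, D) and recurrent (H, H) weights for gates r, z, n
    """
    r = sigmoid(W["r"] @ x + U["r"] @ h)        # reset gate
    z = sigmoid(W["z"] @ x + U["z"] @ h)        # update gate
    n = np.tanh(W["n"] @ x + r * (U["n"] @ h))  # candidate state
    return (1.0 - z) * n + z * h                # blend candidate with old state


rng = np.random.default_rng(0)
HIDDEN, GOAL_DIM, CTX_DIM = 4, 3, 3  # hypothetical sizes
goal_embedding = rng.normal(size=GOAL_DIM)
attended_context = rng.normal(size=CTX_DIM)  # e.g. an attention-module output
x = np.concatenate([goal_embedding, attended_context])  # assumed combination
W = {g: rng.normal(size=(HIDDEN, GOAL_DIM + CTX_DIM)) for g in "rzn"}
U = {g: rng.normal(size=(HIDDEN, HIDDEN)) for g in "rzn"}

h = np.zeros(HIDDEN)
for _ in range(3):  # unroll a few decoding steps with the same input
    h = gru_step(h, x, W, U)
print(h.shape)  # (4,)
```

Because the new state is a gated blend of a `tanh` candidate and the old state, it stays bounded in (-1, 1) when started from zeros, which keeps decoding stable over many steps.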

The attention module attends to the environment and local context of a given goal: it pairs the current state with each premise embedding, scores each pair with the feed-forward network, and passes the scores through a softmax, which weights each embedding's contribution to the final context vector. This process is shown below in the forward pass of the ContextReader.

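import torch
import torch.nn as nn
import torch.nn.functional as F


class ContextReader(nn.Module):
    def __init__(self, opts):
        super().__init__()
        ...

    def forward(self, states, embeddings):
        """Attend over premise embeddings (environment or local context).

        states: (batch_size, state_dim) tensor, one state per proof step.
        embeddings: list of (num_premises, embed_dim) tensors, one per proof step.
        Returns a (batch_size, embed_dim) tensor of attention-weighted contexts.
        """
        assert states.size(0) == len(embeddings)
        context = []
        for state, embedding in zip(states, embeddings):
            if embedding.size(0) == 0:  # no premises: fall back to a default vector
                context.append(self.default_context)
            else:
                # pair the state with every premise embedding
                inputs = torch.cat(
                    [state.unsqueeze(0).expand(embedding.size(0), -1), embedding], dim=1
                )
                # rich scoring network: (Linear -> PReLU -> LayerNorm) x 2 -> Linear
                weights = self.layer_norm1(self.prelu1(self.linear1(inputs)))
                weights = self.layer_norm2(self.prelu2(self.linear2(weights)))
                weights = self.linear3(weights)
                # normalize scores across premises, then take the weighted sum
                weights = F.softmax(weights, dim=0)
                context.append(torch.matmul(embedding.t(), weights).squeeze(-1))
        context = torch.stack(context)
        return context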

The diagram below illustrates the modified attention mechanism.

Modified attention mechanism.
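Written as equations, and assuming `linear3` has a single output unit (which the softmax over `dim=0` and the product with `embedding.t()` imply), the `ContextReader` forward pass maps a state $s$ and premise embeddings $e_1, \dots, e_n$ to a context vector $c$. Here $[\,\cdot\,;\,\cdot\,]$ is concatenation, $\mathrm{LN}$ is layer normalization, and the bias terms assume the default `nn.Linear` setting:

$$
\begin{aligned}
u_i &= \mathrm{LN}_1\!\big(\mathrm{PReLU}_1(W_1 [\,s\,;\,e_i\,] + b_1)\big) \\
v_i &= \mathrm{LN}_2\!\big(\mathrm{PReLU}_2(W_2 u_i + b_2)\big) \\
a_i &= w_3^{\top} v_i + b_3 \\
\alpha_i &= \frac{\exp(a_i)}{\sum_{k=1}^{n} \exp(a_k)} \\
c &= \sum_{i=1}^{n} \alpha_i \, e_i
\end{aligned}
$$

When $n = 0$, $c$ is `default_context` instead. The original two-layer module corresponds to dropping the $v_i$ layer and both layer norms and using ReLU in place of PReLU.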

Coq uses “tactics” to expand the current proof tree until it forms a complete proof (or fails). Tactics are also represented as trees, and the tactic decoder uses the output of the GRU to guide beam search over tactics, ideally yielding viable, high-value tactics for the current proof step.

Tactic Decoder and Attention Module.
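Beam search keeps several partial candidates alive at each step instead of committing greedily to the single best one. The sketch below is a generic, flat-sequence beam search; the actual decoder expands tactic trees, so treat this only as an illustration of how candidates are ranked and pruned. `step_logprobs` and the toy model `toy_step` are hypothetical.

```python
import heapq
import math


def beam_search(step_logprobs, beam_width, max_len, eos="</s>"):
    """Return finished token sequences ranked by total log-probability.

    step_logprobs(prefix) -> dict mapping each possible next token to its log-prob.
    """
    beams = [(0.0, [])]  # (cumulative log-prob, token prefix)
    finished = []
    for _ in range(max_len):
        candidates = [
            (score + lp, prefix + [tok])
            for score, prefix in beams
            for tok, lp in step_logprobs(prefix).items()
        ]
        beams = []
        # keep only the beam_width best expansions; completed ones leave the beam
        for score, seq in heapq.nlargest(beam_width, candidates, key=lambda c: c[0]):
            if seq[-1] == eos:
                finished.append((score, seq[:-1]))
            else:
                beams.append((score, seq))
        if not beams:
            break
    return sorted(finished, key=lambda c: c[0], reverse=True)


def toy_step(prefix):
    """Hypothetical decoder: pick one tactic, then stop."""
    if not prefix:
        return {"intros": math.log(0.5), "omega": math.log(0.3), "auto": math.log(0.2)}
    return {"</s>": 0.0}


ranked = beam_search(toy_step, beam_width=2, max_len=3)
print([seq for _, seq in ranked])  # [['intros'], ['omega']]
```

With `beam_width=2`, the lowest-scoring tactic `auto` is pruned after the first step; the ranked list is what the prover would then try in order.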

Testing

The testing phase involves a reinforcement learning (RL) agent that interacts with the Coq ITP. In this phase, a goal is presented along with its starting environment and local context. At each proof step, the encoder embeds this information, the decoder generates a list of candidate tactics, and the agent dispatches those tactics to the ITP, receiving either a failure signal or a success signal with a list of new subgoals. The resulting node then becomes part of the “frontier” of the proof tree. The agent explores this frontier using depth-first search until all subgoals have been proved, completing the proof, or until a predetermined budget is exhausted. By default, CoqGym's budget is 300 tactics and 10 minutes per proof, and we adhered to these limits.

Testing takes a very long time. Each theorem must be proved by expanding its proof tree until the proof is complete, and when a branch fails, the agent backtracks and tries different tactics at various internal nodes of the tree. Furthermore, when tests are run in multiple processes, each process has to load its own copy of the neural network, which quickly becomes expensive.
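The search loop described above can be sketched as a recursive depth-first search that backtracks when a tactic's subgoals cannot all be proved and stops once the tactic budget is spent. In this minimal sketch, `propose_tactics` stands in for the encoder-decoder, `apply_tactic` for the Coq ITP, and the toy rules are hypothetical; only the 300-tactic budget is modeled, not the 10-minute time limit.

```python
def prove(goal, propose_tactics, apply_tactic, budget):
    """Depth-first proof search with backtracking and a shared tactic budget.

    propose_tactics(goal) -> tactics ranked best first (e.g. the decoder's beam).
    apply_tactic(goal, tactic) -> None if the ITP rejects the tactic, else the
        list of new subgoals (an empty list means the goal is closed).
    budget -> {"tactics": remaining tactic applications}, shared across calls.
    Returns the (goal, tactic) steps of a proof, or None if none was found.
    """
    for tactic in propose_tactics(goal):
        if budget["tactics"] <= 0:
            return None  # budget exhausted: give up on this goal
        budget["tactics"] -= 1
        subgoals = apply_tactic(goal, tactic)
        if subgoals is None:
            continue  # failure signal: try the next-ranked tactic
        steps = [(goal, tactic)]
        for sub in subgoals:  # every new subgoal must be proved
            sub_steps = prove(sub, propose_tactics, apply_tactic, budget)
            if sub_steps is None:
                break  # backtrack to this goal and try another tactic
            steps += sub_steps
        else:
            return steps  # all subgoals proved
    return None


# Hypothetical ITP: maps (goal, tactic) to the subgoals it produces
RULES = {("A & B", "split"): ["A", "B"], ("A", "assumption"): [], ("B", "omega"): []}


def toy_apply(goal, tactic):
    return RULES.get((goal, tactic))


def toy_propose(goal):
    return ["omega", "assumption", "split"]


budget = {"tactics": 300}
print(prove("A & B", toy_propose, toy_apply, budget))
# [('A & B', 'split'), ('A', 'assumption'), ('B', 'omega')]
```

Failed tactics still consume budget, which is why a poorly ranked tactic list makes proofs both slower and more likely to time out.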

Because each proof can take up to 10 minutes or 300 tactics, completing all 13,137 test proofs would take approximately 46 days on a single core if each proof took 5 minutes on average. Even multi-processed across 12 cores, testing would take nearly 4 days. For this reason, we limited our tests and iterations to a small subset of the proofs, the ZFC project, which consists of 237 proofs about Zermelo-Fraenkel set theory.
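The estimate above is simple arithmetic; the snippet below reproduces it under the same assumed average of 5 minutes per proof and, for comparison, applies the same assumption to the 237-proof ZFC subset.

```python
TEST_PROOFS = 13_137
MINUTES_PER_PROOF = 5  # assumed average, half of the 10-minute cap
CORES = 12

single_core_days = TEST_PROOFS * MINUTES_PER_PROOF / 60 / 24
multi_core_days = single_core_days / CORES
zfc_hours = 237 * MINUTES_PER_PROOF / 60  # the ZFC subset, single core

print(f"{single_core_days:.1f} days on 1 core")        # 45.6 days on 1 core
print(f"{multi_core_days:.1f} days on {CORES} cores")  # 3.8 days on 12 cores
print(f"{zfc_hours:.2f} hours for ZFC on 1 core")      # 19.75 hours for ZFC on 1 core
```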

Design Space and Experiments

Our changes fall along three principal dimensions: (1) encoder model, (2) decoder attention mechanism, and (3) feature representation. For (1), we tested GraphSage and Graph Attention Network (GAT) convolution layers, each with both mean and max pooling. For (2), we tested both the original and our rich attention mechanism. For (3), we tested both one-hot encoding of node types and shallow feature embeddings. We hypothesized that shallow feature embeddings would place node types that behave similarly in term ASTs closer together in the embedding space and thereby improve the learning capacity of our models (and they did). Below is a trimmed version of the TermEncoder with the relevant elements exposed.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym.models.encoder import IntegerFeatureEncoder


class TermEncoder(torch.nn.Module):  # StackGNN
    def __init__(self, opts):
        super(TermEncoder, self).__init__()
        self.opts = opts
        ...

        # feature encoder: shallow embedding of integer node-type ids
        # (nonterminals, the list of AST node types, is defined outside this excerpt)
        self.feature_encoder = IntegerFeatureEncoder(self.input_dim,
                                                     len(nonterminals))

        # conv layers (GraphSage or GAT, selected by opts.model_type)
        conv_model = self.build_conv_model(opts.model_type)
        self.convs = nn.ModuleList()
        self.convs.append(
            conv_model(self.input_dim, self.hidden_dim // opts.heads, heads=opts.heads)
        )
        assert opts.num_layers >= 1, "Number of layers is not >=1"
        for i in range(opts.num_layers - 1):
            self.convs.append(
                conv_model(
                    self.hidden_dim, self.hidden_dim // opts.heads, heads=opts.heads
                )
            )

        # post message passing
        self.post_mp = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Dropout(opts.dropout),
            nn.Linear(self.hidden_dim, self.output_dim),
        )

        self.dropout = opts.dropout
        self.num_layers = opts.num_layers

        # pooling
        self.pool = pyg_nn.global_max_pool  # self.pool = pyg_nn.dense_diff_pool
        self.post_pool = nn.Linear(self.output_dim, self.output_dim)

    def forward(self, proof_step):
        """forward pass through graph neural network"""
        # move the whole batch onto the device first, so the node-type ids and
        # the embedding weights used by the feature encoder share a device
        proof_step = proof_step.to(self.opts.device)

        # replace integer node types with shallow embeddings, then unpack batch
        proof_step = self.feature_encoder(proof_step)
        x, edge_index, batch = proof_step.x, proof_step.edge_index, proof_step.batch

        # message passing
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # post message passing
        x = self.post_mp(x)

        # global pooling: one embedding per term (AST) in the batch
        x = self.pool(x, batch)

        return self.post_pool(x)
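To make the one-hot versus shallow-embedding comparison concrete: with one-hot features every pair of distinct node types is equally far apart, whereas a shallow embedding is a trainable lookup table (GraphGym's `IntegerFeatureEncoder` wraps `torch.nn.Embedding`) whose rows can drift closer together for types that behave similarly. The NumPy sketch below uses a hypothetical 5-type vocabulary and a random table in place of learned weights.

```python
import numpy as np

NUM_NODE_TYPES = 5  # hypothetical; the real encoder uses len(nonterminals)
EMB_DIM = 3         # hypothetical embedding width

node_types = np.array([0, 2, 2, 4])  # integer node-type id per AST node

# One-hot: distinct types are all exactly sqrt(2) apart, so the input says
# nothing about which node types resemble each other.
one_hot = np.eye(NUM_NODE_TYPES)[node_types]

# Shallow embedding: row lookup in a trainable table (random stand-in here).
rng = np.random.default_rng(0)
table = rng.normal(size=(NUM_NODE_TYPES, EMB_DIM))
embedded = table[node_types]

print(one_hot.shape, embedded.shape)  # (4, 5) (4, 3)
```

During training, gradients update only the rows of the table that are looked up, so node types that appear in similar contexts can end up with nearby embeddings, which is the effect the hypothesis above relies on.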

The convolutions in the above term encoder are either GraphSage or Graph Attention Network (GAT) layers. Below, we show our GAT code and the attention weight equation. In our testing, we always used 2 heads, each with dimension 128.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_scatter
import torch_geometric.utils as pyg_utils
from torch_geometric.nn import MessagePassing


class GAT(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=2, negative_slope=0.2,
                 dropout=0.0, **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)
        ...

        # linear layers (one projection shared by source and target nodes)
        self.lin_l = torch.nn.Linear(in_channels, out_channels * heads, bias=False)
        self.lin_r = self.lin_l

        # attention parameters: one vector per head for each end of an edge
        self.att_l = nn.Parameter(torch.zeros(self.heads, self.out_channels))
        self.att_r = nn.Parameter(torch.zeros(self.heads, self.out_channels))

    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        """derive attention weights and apply to neighborhood messages"""
        # unnormalized attention score for each edge j -> i, per head
        alpha = F.leaky_relu(alpha_i + alpha_j, self.negative_slope)

        # softmax over the incoming edges of each target node i
        alpha = pyg_utils.softmax(alpha, index=index, ptr=ptr, num_nodes=size_i)

        # attention dropout (active only in training mode)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # weight each neighbor's projected features: (E, H, 1) * (E, H, C)
        out = alpha.unsqueeze(-1) * x_j

        return out

    def aggregate(self, inputs, index, dim_size=None):
        """sum aggregation of neighborhood messages"""
        out = torch_scatter.scatter(
            inputs, index=index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
        )

        return out

    def forward(self, x, edge_index, size=None):
        """forward pass through graph attention convolution"""
        H, C = self.heads, self.out_channels

        # project node features and split into heads: (N, H, C)
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = self.lin_r(x).view(-1, H, C)

        # per-node attention logits for each head: (N, H)
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)

        # message propagation
        out = self.propagate(
            edge_index=edge_index, alpha=(alpha_l, alpha_r), x=(x_l, x_r), size=size
        )

        # concatenate heads: (N, H * C)
        out = out.view(-1, H * C)

        return out
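Reading the attention weights off the `GAT` class above gives the standard GAT formulation [6]. For an edge from node $j$ to node $i$ and head $k$, with the shared projection $W^{(k)}$ (`lin_l`, reused as `lin_r`) and attention vectors $\mathbf{a}_l^{(k)}$ and $\mathbf{a}_r^{(k)}$ (`att_l` and `att_r`):

$$
\begin{aligned}
e_{ij}^{(k)} &= \mathrm{LeakyReLU}_{0.2}\!\left(\mathbf{a}_r^{(k)\top} W^{(k)} x_i + \mathbf{a}_l^{(k)\top} W^{(k)} x_j\right) \\
\alpha_{ij}^{(k)} &= \frac{\exp\big(e_{ij}^{(k)}\big)}{\sum_{m \in \mathcal{N}(i)} \exp\big(e_{im}^{(k)}\big)} \\
h_i &= \big\Vert_{k=1}^{K} \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(k)} \, W^{(k)} x_j
\end{aligned}
$$

Here $\mathcal{N}(i)$ is the set of source nodes with an edge into $i$ in `edge_index` (self-loops appear only if they are present there; the code shown does not add them), dropout is applied to $\alpha_{ij}^{(k)}$ during training, and the $K$ head outputs are concatenated by `out.view(-1, H * C)`. This is equivalent to the form $\mathrm{LeakyReLU}\big(\mathbf{a}^{\top}[W x_i \,\Vert\, W x_j]\big)$ in [6] with $\mathbf{a} = [\mathbf{a}_r ; \mathbf{a}_l]$.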

In phase one of experimentation, we trained models for 4 epochs and evaluated them on proofs in the ZFC dataset. The following table shows our results from this exploration phase.

Given these results, we trained the following models for 10 epochs: (1) a GraphSage model with max pooling, shallow feature embeddings, and rich attention, and (2) a Graph Attention Network (GAT) with mean pooling, shallow feature embeddings, and rich attention.

Next Steps

While we have demonstrated that GNNs can improve performance on at least a subset of the CoqGym test proofs, we hope to expand both our experimentation and our testing. In our next iteration, we plan to experiment with Differentiable Pooling (DiffPool), which constructs graph-level (AST) embeddings through a sequence of hierarchical pooling layers. The HOList paper [5] also found that the performance of their GNN kept improving with up to 12 hops of message passing. We only use 2 hops, so this may be an easy source of significant improvement. We also plan to test ID-GNNs, which are provably more expressive than GraphSage or GAT models, to create richer node embeddings. Finally, while working with the CoqGym dataset has taught us a great deal, we have also learned of another dataset, Isabelle’s Archive of Formal Proofs (AFP), which may be even easier to use and permit faster model iteration. Google’s HOList team has recently switched to this same dataset, so it will be exciting to compare our results with their benchmarks in our next iteration.
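Differentiable pooling coarsens a graph by softly assigning its n nodes to k clusters with an assignment matrix `S`, then pooling features as `S^T Z` and connectivity as `S^T A S`. The NumPy sketch below shows one such step under simplifying assumptions: in DiffPool proper, both `Z` and the assignment scores are produced by GNNs and auxiliary link-prediction and entropy losses are added, all of which we omit here. PyTorch Geometric's `dense_diff_pool` (mentioned in a comment in the `TermEncoder`) implements the full version.

```python
import numpy as np


def softmax_rows(m):
    m = m - m.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(m)
    return e / e.sum(axis=1, keepdims=True)


def diffpool_step(A, Z, assign_logits):
    """One DiffPool coarsening step (auxiliary losses omitted).

    A: (n, n) adjacency matrix; Z: (n, d) node embeddings;
    assign_logits: (n, k) scores for assigning each node to one of k clusters.
    """
    S = softmax_rows(assign_logits)  # soft assignment; each row sums to 1
    X_coarse = S.T @ Z               # (k, d) cluster embeddings
    A_coarse = S.T @ A @ S           # (k, k) weighted cluster connectivity
    return X_coarse, A_coarse, S


# A 4-node path graph 0-1-2-3 pooled into 2 clusters
A = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
rng = np.random.default_rng(0)
Z = rng.normal(size=(4, 3))
assign_logits = np.array([[4.0, 0.0], [4.0, 0.0], [0.0, 4.0], [0.0, 4.0]])
X_coarse, A_coarse, S = diffpool_step(A, Z, assign_logits)
print(X_coarse.shape, A_coarse.shape)  # (2, 3) (2, 2)
```

Stacking several such steps, each followed by further message passing on the coarsened graph, yields the hierarchical graph-level AST embedding described above.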

References

[1] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton. “Layer Normalization”. 2016. arXiv: 1607.06450 [stat.ML].
[2] Yutian Chen et al. “Bayesian Optimization in AlphaGo”. In: CoRR abs/1812.06855 (2018). arXiv: 1812.06855. url: http://arxiv.org/abs/1812.06855.
[3] Chi Thang Duong et al. “On Node Features for Graph Neural Networks”. In: CoRR abs/1911.08795 (2019). arXiv: 1911.08795. url: http://arxiv.org/abs/1911.08795.
[4] William L. Hamilton, Rex Ying, and Jure Leskovec. “Inductive Representation Learning on Large Graphs”. In: CoRR abs/1706.02216 (2017). arXiv: 1706.02216. url: http://arxiv.org/abs/1706.02216.
[5] Aditya Paliwal et al. “Graph Representations for Higher-Order Logic and Theorem Proving”. Sept. 12, 2019. doi: 10.48550/arXiv.1905.10006. arXiv: 1905.10006 [cs, stat]. url: http://arxiv.org/abs/1905.10006 (visited on 01/20/2023).
[6] Petar Veličković et al. “Graph Attention Networks”. 2018. arXiv: 1710.10903 [stat.ML].
[7] Kaiyu Yang and Jia Deng. “Learning to Prove Theorems via Interacting with Proof Assistants”. May 21, 2019. doi: 10.48550/arXiv.1905.09381. arXiv: 1905.09381 [cs, stat]. url: http://arxiv.org/abs/1905.09381 (visited on 01/20/2023).
[8] Jiaxuan You, Rex Ying, and Jure Leskovec. “Design Space for Graph Neural Networks”. In: NeurIPS. 2020.
