AlphaMath: Elevating Mathematical Reasoning in Large Language Models with Monte Carlo Tree Search

Paulstoienescu
7 min read · Sep 25, 2024

Introduction

Recent advancements in large language models (LLMs) have significantly improved their mathematical reasoning capabilities. However, these models still struggle with complex problems that require multiple reasoning steps, often making logical or numerical errors. In this study, the authors present an approach that eliminates the need for process annotations (from humans or GPT models) by employing the Monte Carlo Tree Search (MCTS) framework. The method autonomously generates both process supervision and step-level evaluation signals. By iteratively training the policy and value models, it harnesses the strengths of a well-pretrained LLM to gradually refine its mathematical reasoning abilities.

In this research, the authors aim for LLMs to develop a human-like capability for self-evolution, enabling them to autonomously enhance their use of knowledge. They hypothesize that well-pretrained LLMs already have the requisite mathematical knowledge to produce correct reasoning but need proper stimulation — such as refined prompts or an optimized search strategy — to effectively do so. In this work, solutions involving both textual analysis and code snippets are autonomously generated by a well-pretrained LLM, which is guided by carefully crafted prompts and a thoughtfully designed Monte Carlo Tree Search (MCTS) framework.

To improve the efficiency of solution generation, the authors integrate a value model into the same LLM by adding a linear layer. This eliminates the need for time-consuming rollouts to estimate rewards. As the LLM learns to solve mathematical problems from its own annotated solutions, the value model concurrently learns to evaluate the quality of intermediate reasoning steps from the corresponding state values in the MCTS, mimicking human judgment. During inference, the value model allows the LLM to run MCTS, which substantially boosts its reasoning abilities but is computationally expensive. To address this, the authors introduce a step-level beam search strategy adapted from the standard beam search algorithm. In this approach, the value model helps the policy model (i.e., the LLM) explore more effective solution paths instead of relying solely on prior probabilities. Step-level beam search significantly improves the LLM's reasoning at a minimal computational cost compared with greedy decoding or MCTS inference.

AlphaMath

For any given input question, the solution process can be divided into multiple reasoning steps (e.g., by segmenting the solution into distinct stages or simply splitting at punctuation such as periods). From this standpoint, mathematical problem-solving is framed as a reinforcement learning problem. Consider a complete solution comprising T reasoning steps. At step t, the partial solution, consisting of the question and the first t steps, is the state, and the next candidate reasoning step is the action. The policy model is the large language model itself, and the transition from one state to the next is deterministic: the new state is simply the concatenation of the current state and the chosen action.
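
To make this formulation concrete, here is a minimal Python sketch of a state, the deterministic concatenation transition, and a naive period-based step splitter. The `State` class and the helper names are hypothetical illustrations, not the authors' code, and splitting on periods would also break decimals such as "1.5".

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class State:
    """A partial solution: the question plus the reasoning steps taken so far."""
    question: str
    steps: Tuple[str, ...] = ()

    def text(self) -> str:
        # The prompt the policy LLM would condition on to propose the next step.
        return "\n".join((self.question,) + self.steps)


def transition(state: State, action: str) -> State:
    """Deterministic transition: append the chosen step (concatenation)."""
    return State(state.question, state.steps + (action,))


def split_into_steps(solution: str) -> List[str]:
    """Naive segmentation of a finished solution into steps at periods."""
    parts = [p.strip() for p in solution.split(".")]
    return [p + "." for p in parts if p]


s0 = State("What is 2 + 3 * 4?")
steps = split_into_steps("3 * 4 = 12. 2 + 12 = 14.")
final = s0
for step in steps:
    final = transition(final, step)  # s_{t+1} = concat(s_t, a_{t+1})
```

Because `State` is immutable, every transition produces a new state, which mirrors the fact that each node in the search tree stores its own partial solution.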

The primary objective of the authors is to develop a step-level value model capable of evaluating the expected returns from the current partial solution, thereby guiding the LLM to select more appropriate subsequent reasoning steps.

Before the (k+1)-th round of training, the authors have a value model and an LLM policy model, which in their study are essentially the same model with different final layers. Using these models, they build an MCTS-based inference algorithm. The algorithm starts from the initial state as the root and, combining the policy and value models, expands the search tree by adding new nodes that correspond to states judged to have high potential based on the outcomes of simulated trajectories. For mathematical problem-solving, the authors customize the four key MCTS operations (selection, expansion, evaluation, and backup) and then derive value-training targets from the finished tree:

1. Selection. During the i-th simulation of the MCTS, the search starts from s0, the initial state containing the input question, and descends the tree by selecting child nodes with a variant of the PUCT (Predictor + Upper Confidence bounds applied to Trees) algorithm. This variant balances exploitation of nodes with high estimated values against exploration of less-visited nodes to which the policy assigns high prior probability.

2. Expansion. Back-tracing from the selected leaf node to the root constructs a partial solution, which serves as the prompt for expanding that node. Because the LLM can in principle produce an unlimited number of possible actions (token sequences), the authors sample with a higher temperature to keep the generated actions diverse, which helps the search explore a broader range of candidate solutions.

3. Evaluation. The leaf node (partial solution) identified during selection is scored with a weighted sum that combines the value model's estimate with the reward. If the node is terminal (i.e., it is a complete solution), the corresponding reward is returned; otherwise, the value model predicts the expected return from that state onward. These estimates guide the search in later simulations.

4. Backup. At the end of the i-th simulation, a backward pass updates every edge on the path from the leaf node back to the root: the edge's visit count is incremented by one, and its state-action value is updated to the average of the leaf evaluations from all simulations that have passed through that edge.

5. After N simulations, the final tree stores the expanded nodes and their state-action values. Because the transition function is deterministic and rewards are only given at terminal nodes, these Q-values can be used directly as training signals: the state value of each non-terminal node is fitted to the Q-value of the edge that leads into it.
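
The five steps above can be sketched in plain Python. This is a simplified illustration under several assumptions: `propose` and `value_fn` are hypothetical stand-ins for the policy LLM and the value head, the exploration term uses the AlphaZero-style square root of the parent's visit count, terminal rewards are taken to be +1/-1 for correct/incorrect answers, and the evaluation step simply returns the terminal reward or the value estimate rather than modeling the weighted sum.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple


@dataclass
class Node:
    state: str                    # partial solution: question + steps so far
    prior: float = 1.0            # policy probability of the step leading here
    parent: Optional["Node"] = None
    terminal: bool = False
    reward: float = 0.0           # assumed +1 correct / -1 incorrect at terminals
    children: List["Node"] = field(default_factory=list)
    visits: int = 0               # N(s, a) of the edge into this node
    value_sum: float = 0.0        # sum of evaluations backed up through the edge

    @property
    def q(self) -> float:
        # Q(s, a) as the running mean of backed-up evaluations.
        return self.value_sum / self.visits if self.visits else 0.0


def puct(child: Node, parent: Node, c_puct: float) -> float:
    # Exploit high Q; explore children with a high prior and few visits.
    return child.q + c_puct * child.prior * math.sqrt(parent.visits) / (1 + child.visits)


# Hypothetical LLM stand-ins: propose(state) -> [(step, prior, is_terminal, reward)]
Propose = Callable[[str], List[Tuple[str, float, bool, float]]]
ValueFn = Callable[[str], float]


def simulate(root: Node, propose: Propose, value_fn: ValueFn, c_puct: float = 1.0) -> None:
    # 1. Selection: descend by PUCT score until reaching a leaf.
    node = root
    while node.children:
        parent = node
        node = max(parent.children, key=lambda ch: puct(ch, parent, c_puct))
    # 2. Expansion: attach sampled candidate next steps to a non-terminal leaf.
    if not node.terminal:
        for step, prior, is_term, reward in propose(node.state):
            node.children.append(Node(node.state + "\n" + step, prior, node, is_term, reward))
    # 3. Evaluation (simplified): terminal reward, otherwise the value estimate.
    v = node.reward if node.terminal else value_fn(node.state)
    # 4. Backup: bump visit counts and accumulate v from the leaf up to the root.
    while node is not None:
        node.visits += 1
        node.value_sum += v
        node = node.parent


def value_targets(root: Node) -> Dict[str, float]:
    # 5. Targets: for visited non-terminal nodes, V(s_{t+1}) := Q(s_t, a_t).
    targets: Dict[str, float] = {}
    stack = list(root.children)
    while stack:
        node = stack.pop()
        if node.visits and not node.terminal:
            targets[node.state] = node.q
        stack.extend(node.children)
    return targets


# Toy problem: step 1 is "good step" or "bad step"; step 2 ends the solution.
def toy_propose(state: str) -> List[Tuple[str, float, bool, float]]:
    if state.count("\n") >= 1:
        return [("answer", 1.0, True, 1.0 if "good" in state else -1.0)]
    return [("good step", 0.6, False, 0.0), ("bad step", 0.4, False, 0.0)]


root = Node("Q: toy question")
for _ in range(20):
    simulate(root, toy_propose, value_fn=lambda s: 0.0)
targets = value_targets(root)
```

In this toy run the "good step" branch accumulates most of the visits and a higher Q-value, and those Q-values become the regression targets for the value head.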

Training

Training starts from a pretrained LLM that serves as the policy model. The model is extended with an auxiliary linear layer with a tanh activation, which runs in parallel with the standard softmax layer used for token prediction. This extra layer lets the model learn value predictions, so it can assess the quality of intermediate reasoning steps and guide the choice of more effective solutions during problem solving.
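
As a rough illustration of the two parallel heads, the NumPy sketch below applies a token softmax and a tanh-bounded scalar value head to the same final hidden states. The sizes and random weights are made up, and reading a step's value at its last token is an assumption about how the per-step score would be taken.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, vocab_size = 16, 50

# Hypothetical weights: the pretrained LM head plus a new scalar value head.
W_lm = rng.normal(scale=0.02, size=(hidden_size, vocab_size))
W_v = rng.normal(scale=0.02, size=(hidden_size, 1))
b_v = np.zeros(1)


def heads(h: np.ndarray):
    """h: (seq_len, hidden_size) final hidden states from the shared transformer."""
    logits = h @ W_lm
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    probs = exp / exp.sum(axis=-1, keepdims=True)         # token softmax head
    value = np.tanh(h @ W_v + b_v)[:, 0]                  # scalar in (-1, 1) per position
    return probs, value


h = rng.normal(size=(5, hidden_size))  # 5 token positions
probs, value = heads(h)
step_value = value[-1]  # assumption: a step's value is read at its last token
```

Bounding the value with tanh keeps predictions on the same [-1, 1] scale as a +1/-1 terminal reward, which is one plausible reason for that activation choice.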

From the tree built during the k-th round of MCTS, the authors sample solution paths that end in terminal nodes with both correct and incorrect answers, together with the value estimate of every node on those paths. A multi-task loss then updates the policy and value models jointly: the policy learns from the correctness of the final answers, while the value model learns to match the MCTS estimates of intermediate steps. This joint training helps the model follow more effective reasoning paths in later iterations.
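
A hedged sketch of such a multi-task objective follows. The policy term is assumed to be the negative log-likelihood of paths whose final answer is correct, the value term a mean squared error against the MCTS estimates for nodes on all sampled paths, and `beta` is a hypothetical weighting coefficient; the exact form used by the authors may differ.

```python
import numpy as np


def multitask_loss(token_logprobs, is_correct, pred_values, target_values, beta=1.0):
    """
    token_logprobs: list of 1-D arrays, log pi(token) along each sampled path
    is_correct:     list of bools, whether each path reached the right answer
    pred_values:    V_phi(s_t) for every node on every sampled path
    target_values:  MCTS value estimates for the same nodes
    """
    correct = [lp for lp, ok in zip(token_logprobs, is_correct) if ok]
    # Policy term (assumption): imitate only paths whose final answer is correct.
    policy = -np.mean([lp.sum() for lp in correct]) if correct else 0.0
    # Value term: regress step values toward MCTS estimates on all paths.
    value = np.mean((np.asarray(pred_values) - np.asarray(target_values)) ** 2)
    return policy + beta * value


paths = [np.log([0.5, 0.5]), np.log([0.9])]
loss = multitask_loss(paths, [True, False], [0.2, -0.1], [0.5, -1.0])
```

Incorrect paths still contribute to the value term, which is what lets the value head learn to recognize unpromising intermediate steps.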

Inference

For MCTS inference, once the tree has been fully constructed, the algorithm walks it from the root in a top-down manner, keeping the top-B1 steps at each level (B1 is usually set to 1 for MCTS) according to the Q-values stored in the child nodes. At each level, all children of the previously selected B1 nodes are re-ranked together by Q-value, and the top-B1 are retained for the next iteration. This continues until the most promising solution path is reached.
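
The top-down, Q-guided extraction can be sketched as below, assuming a finished tree whose nodes already store Q-values. `TreeNode` and `extract_paths` are hypothetical names; with `b1=1` the procedure reduces to greedily following the highest-Q child at each level.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class TreeNode:
    step: str
    q: float = 0.0
    terminal: bool = False
    children: List["TreeNode"] = field(default_factory=list)


def extract_paths(root: TreeNode, b1: int = 1) -> List[List[str]]:
    beams: List[Tuple[List[str], TreeNode]] = [([], root)]
    finished: List[List[str]] = []
    while beams:
        candidates = []
        for path, node in beams:
            if node.terminal or not node.children:
                finished.append(path)  # this beam has reached a complete solution
                continue
            for child in node.children:
                candidates.append((path + [child.step], child))
        # Re-rank all children of the retained nodes together by Q; keep top-B1.
        candidates.sort(key=lambda pc: pc[1].q, reverse=True)
        beams = candidates[:b1]
    return finished


root = TreeNode("root", children=[
    TreeNode("A", 0.3, children=[TreeNode("A1", 0.9, terminal=True)]),
    TreeNode("B", 0.7, children=[TreeNode("B1", -0.5, terminal=True),
                                 TreeNode("B2", 0.2, terminal=True)]),
])
greedy = extract_paths(root, b1=1)
wider = extract_paths(root, b1=2)
```

In this toy tree, `b1=1` commits to B because its Q (0.7) beats A's (0.3) and never reaches the higher-valued A1, while `b1=2` keeps both branches alive.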

However, MCTS is computationally intensive because it requires many simulations, which makes it less practical in production environments. To address this, the authors simplify MCTS inference by removing the backup operation, yielding Step-level Beam Search (SBS). Unlike MCTS, SBS does not build the full tree; instead, at each expansion it keeps only the best child nodes, as scored by the value model. This streamlined selection makes SBS more efficient while still guiding the LLM along effective reasoning paths.
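
A minimal step-level beam search sketch, assuming hypothetical `propose` (samples `b2` candidate next steps from the policy) and `value_fn` (the value head's score) callables. Here `b1` is the number of partial solutions kept after each step; `b2`, the number of candidates sampled per beam, is an assumed parameter. No tree statistics or backup are maintained, which is where the savings over MCTS come from.

```python
from typing import Callable, List, Tuple

# Hypothetical interfaces (assumptions, not a real API):
#   propose(state, k) -> k sampled next steps, each (step_text, is_terminal)
#   value_fn(state)   -> value-model score of a partial or complete solution
Propose = Callable[[str, int], List[Tuple[str, bool]]]
ValueFn = Callable[[str], float]


def step_beam_search(question: str, propose: Propose, value_fn: ValueFn,
                     b1: int = 2, b2: int = 3, max_steps: int = 10) -> str:
    beams = [(question, False)]  # (partial solution, is_terminal)
    for _ in range(max_steps):
        if all(done for _, done in beams):
            break
        candidates = []
        for state, done in beams:
            if done:  # carry finished solutions forward unchanged
                candidates.append((value_fn(state), state, True))
                continue
            for step, is_term in propose(state, b2):
                new_state = state + "\n" + step
                candidates.append((value_fn(new_state), new_state, is_term))
        # Keep the B1 highest-valued candidates; no backup is performed.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = [(s, d) for _, s, d in candidates[:b1]]
    return beams[0][0]


def toy_propose(state: str, k: int) -> List[Tuple[str, bool]]:
    if state.count("\n") == 0:
        return [("x=1", False), ("x=2", False), ("x=3", False)][:k]
    return [("answer " + state.split("\n")[1], True)]


def toy_value(state: str) -> float:
    return 1.0 if "x=2" in state else 0.0


result = step_beam_search("Q", toy_propose, toy_value, b1=2, b2=3)
```

With `b1=1` this becomes a value-guided greedy search that picks the single best child at each expansion.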

Results

Unlike previous works, the proposed AlphaMath framework does not rely on high-quality solutions annotated by humans or GPT-4, whether they are in the form of text analysis or code snippets. While such annotations can enhance a model’s reasoning capabilities, they come with considerable annotation costs. Additionally, the AlphaMath approach stands apart from prior research by not using any external datasets (e.g., additional questions and solutions) beyond the GSM8K and MATH datasets, demonstrating its ability to perform effectively without external annotated resources.
