SIaM: Self-Improving Code-Assisted Mathematical Reasoning of LLMs
The potential of improving LLMs by leveraging large-scale, expert-written, diverse math question-answer pairs remains unexplored. To utilize these resources and tackle unique challenges such as assessing code responses, the authors of this paper [1] propose a novel paradigm that uses a code-based critic model to guide steps including question-code data construction, quality control, and complementary evaluation.
Key contributions:
- first attempt to leverage large-scale web QA pairs to improve the code-assisted mathematical reasoning abilities of LLMs.
- propose a novel iterative self-improving paradigm that employs a new critic model to guide various steps such as data construction and filtering. This critic model can also serve as a complementary evaluation scorer, reducing the reliance on heuristic design for new evaluation tasks.
- Extensive experiments on both English and Chinese tasks demonstrate the effectiveness of the paradigm, and the comprehensive analysis of the key factors in achieving continuous improvement at different stages may shed light on future studies.
Method
i) Training an Initial Model
- first use high-quality seed data to fine-tune an LLM, resulting in model Mseed.
- use Mseed to generate code samples and keep up to four predictions per question whose code execution result matches the reference answer; the seed data and this self-distilled data are then combined to train M0, which serves as the initial model for later stages (a sketch of the filtering step is shown below)
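A minimal sketch of this filtering step, assuming hypothetical helper callables for sampling from Mseed, sandboxed execution, and answer matching (none of these interfaces come from the paper), with a per-question sampling budget chosen here only for illustration:

```python
from typing import Callable, Iterable

def build_self_distilled_data(
    questions: Iterable[dict],
    generate_code: Callable[[str], str],        # samples one code response from Mseed (assumed interface)
    execute_code: Callable[[str], str | None],  # runs the code in a sandbox, returns its result or None
    answers_match: Callable[[str, str], bool],  # compares an execution result with a reference answer
    max_keep: int = 4,                          # keep up to four correct predictions per question
    num_samples: int = 8,                       # per-question sampling budget (an assumption)
) -> list[dict]:
    """Keep only code samples whose execution result matches the reference answer."""
    distilled = []
    for q in questions:
        kept = []
        for _ in range(num_samples):
            code = generate_code(q["question"])
            result = execute_code(code)
            if result is not None and answers_match(result, q["reference_answer"]):
                kept.append({"question": q["question"], "code": code})
                if len(kept) == max_keep:
                    break
        distilled.extend(kept)
    return distilled
```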
ii) Building a Multi-use Code-based Critic Model
Several challenges arise in data utilization, filtering, and evaluation, for example:
- Pattern-based methods that compare predictions with ground-truth answers during validation and evaluation work well for GSM-style datasets, where answers are single, well-formatted numbers. However, pattern-based methods face inherent challenges in handling diverse answer types and formats and in bridging the gap between natural language and programming language. For example, on the MATH dataset, comparing CoT predictions with reference answers in LaTeX-like format already requires human-written patterns and answer conversion
- This complexity is compounded when predictions are presented in code syntax, even when the task is simplified to comparing the reference answer with the code execution result.
To address the above challenges, the authors propose building a code-based critic model optimized with an objective defined over tuples (q, a, c, e), where q denotes a question, a is the reference answer to q, c represents the code response to q, and e is the execution result of code c.
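A plausible form of this objective, assuming the critic is trained with standard maximum likelihood to produce a YES/NO judgment y for each tuple (the paper's exact formulation may differ):

```latex
\max_{\phi}\;\;
\mathbb{E}_{(q,\,a,\,c,\,e,\,y)\,\sim\,D_{\mathrm{critic}}}
\big[\, \log p_{\phi}\!\left(y \mid q,\, a,\, c,\, e\right) \big],
\qquad y \in \{\mathrm{YES},\ \mathrm{NO}\}
```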
iii) Code Data Generation
- for all questions, the authors only use the reference answers to verify the correctness of code execution results instead of directly training on these answers, and only the benchmarks' training sets are used.
- In the (k+1)-th iteration, for each new question, the current policy model πθk is used to generate five code samples, which are executed to obtain the results. For questions in the diverse-format web data, the critic model is then used to predict YES or NO for each response (ai, cij, eij) given qi.
- We use the probability of YES or NO as the confidence value for the critic model’s judgment.
- A higher probability score indicates a greater confidence in the code response, either agreeing with or disagreeing with the reference answer
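A minimal sketch of this judging step with Hugging Face transformers, assuming a hypothetical critic checkpoint path and a simplified prompt template (neither is the paper's released artifact); the confidence is the renormalized probability of the predicted YES/NO token:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

CRITIC_PATH = "path/to/code-based-critic"  # hypothetical fine-tuned critic checkpoint

tokenizer = AutoTokenizer.from_pretrained(CRITIC_PATH)
model = AutoModelForCausalLM.from_pretrained(CRITIC_PATH, torch_dtype=torch.bfloat16, device_map="auto")

PROMPT = (
    "Question: {q}\nReference answer: {a}\n"
    "Code response:\n{c}\nExecution result: {e}\n"
    "Does the code response agree with the reference answer? Answer YES or NO:\n"
)  # simplified prompt; not the paper's exact template

def judge(q: str, a: str, c: str, e: str) -> tuple[str, float]:
    """Return the critic's YES/NO judgment and its probability (the confidence score)."""
    inputs = tokenizer(PROMPT.format(q=q, a=a, c=c, e=e), return_tensors="pt").to(model.device)
    with torch.no_grad():
        next_token_logits = model(**inputs).logits[0, -1]
    yes_id = tokenizer.encode("YES", add_special_tokens=False)[0]
    no_id = tokenizer.encode("NO", add_special_tokens=False)[0]
    probs = torch.softmax(next_token_logits[[yes_id, no_id]], dim=-1)
    return ("YES", float(probs[0])) if probs[0] >= probs[1] else ("NO", float(probs[1]))
```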
iv) Self-Improvement with Unseen Data
- authors perform supervised fine-tuning (SFT) on πθk using a filtered dataset DSFT, where λ1 and λ2 represent thresholds for filtering and difficulty control (see the sketch after this list)
- further leverage negative instances by optimizing the policy model on preference data using algorithms such as DPO [2] and ORPO [3]
- mainly focus on DPO and leave other options for future studies; the policy is jointly trained with the SFT objective to alleviate overfitting to the preference data and ensure a stable update [3]
- For each question, authors use the highest-scoring YES response and the highest-scoring NO response to form a preference pair, aiming to maximize the difference between them
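The sketch below shows how the critic's judgments and confidence scores might be turned into DSFT and DPO preference pairs. The input format is assumed, and treating λ2 as a cap on each question's YES-rate (dropping questions the policy already solves too easily) is an interpretation, not the paper's stated rule:

```python
def build_sft_and_dpo_data(samples_per_question: dict, lambda1: float = 0.9, lambda2: float = 0.8):
    """samples_per_question maps a question to a list of (code, judgment, confidence) triples,
    where judgment is the critic's YES/NO label and confidence its probability.
    Returns (sft_data, dpo_pairs)."""
    sft_data, dpo_pairs = [], []
    for question, samples in samples_per_question.items():
        yes = [(code, conf) for code, label, conf in samples if label == "YES"]
        no = [(code, conf) for code, label, conf in samples if label == "NO"]

        # DSFT: keep the most confident positive above lambda1, and drop questions
        # whose YES-rate exceeds lambda2 (assumed form of the difficulty control).
        if yes and len(yes) / len(samples) <= lambda2:
            best_code, best_conf = max(yes, key=lambda x: x[1])
            if best_conf >= lambda1:
                sft_data.append({"question": question, "code": best_code})

        # DPO pairs: highest-scoring YES response vs. highest-scoring NO response.
        if yes and no:
            chosen, _ = max(yes, key=lambda x: x[1])
            rejected, _ = max(no, key=lambda x: x[1])
            dpo_pairs.append({"prompt": question, "chosen": chosen, "rejected": rejected})
    return sft_data, dpo_pairs
```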
Experiments
i) Data
- Statistics of training data used in the three-stage paradigm (D1 and D2,in-house are Chinese resources; D2,WebInstruct is English-dominant)
a) Seed Data D0:
- used GPT-4-0613 to generate Python code in an iterative fashion: the remaining questions that do not yet have correct code (i.e., code whose execution result matches the question's reference answer) are repeatedly resampled, for up to three iterations (a rough sketch follows)
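A rough sketch of this iterative construction, again with hypothetical helper callables (a GPT-4-0613 sampling wrapper, sandboxed execution, and answer matching):

```python
from typing import Callable

def iterative_seed_generation(
    questions: list[dict],
    sample_gpt4_code: Callable[[str], str],      # e.g., a wrapper around GPT-4-0613 (assumed interface)
    execute_code: Callable[[str], str | None],
    answers_match: Callable[[str, str], bool],
    max_iterations: int = 3,                     # up to three iterations, as described above
) -> list[dict]:
    """Resample only the questions that do not yet have a correct program."""
    remaining, seed_data = list(questions), []
    for _ in range(max_iterations):
        if not remaining:
            break
        unsolved = []
        for q in remaining:
            code = sample_gpt4_code(q["question"])
            result = execute_code(code)
            if result is not None and answers_match(result, q["reference_answer"]):
                seed_data.append({"question": q["question"], "code": code})
            else:
                unsolved.append(q)
        remaining = unsolved
    return seed_data
```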
b) Value-Style D1:
- utilize the initial policy M0 to generate code samples for questions in the training sets of two open-source math word problem datasets, APE (200.5K) and CM (13.6K), both collected from educational websites and covering elementary and middle-school levels
c) Diverse-Format Data D2 and Critic Data:
- For each question, we retain only one positive code and one negative code (if any exists) judged by the critic.
- To evaluate the generalization and robustness of our paradigm, we also use a recently released large-scale QA dataset named WebInstruct to construct a similar-scale D2, containing 447K preference pairs
ii) Implementation
- used LLaMA-Factory [4] for efficient fine-tuning, built upon DeepSpeed (ZeRO-3).
- experimented with various LLMs to select backbone models, such as CodeLlama-7B-Python, Llama3instruct, CodeQwen1.5-7B-Chat, QWEN2, and Deepseek-Coder-7B-instruct-v1.5, which demonstrate strong coding capabilities on code-related benchmarks.
iii) Performance of Initial Policy and Self-improved LLMs
- Table below outlines accuracy across the development sets of math datasets, where DeepSeekcode, Llama3instruct, and QWEN2Mathinstruct are used as initial policy models (i.e., M0) for self-improvement, since they showed superior average performance among the investigated models when trained with the seed data
- Table below outlines the impact of different stages and data selection on the development sets of the datasets
- observe that self-improving the initial policy model with Chinese-only data, D1 and D2, does not hurt the accuracy of M2 on English tasks. In fact, it may be beneficial (e.g., +1.5% on both MATH and GSM8K datasets using DeepSeekcode).
iv) Comparison of different data choices and alignment methods
a) Diversity
- Table below shows the self-improved average accuracy of Llama3instruct on the development sets of different datasets with various training strategies and data.
- Based on the experimental results, given D0 and D1, we observe that two-stage SFT (first on D0 for two epochs and then on D1 for two epochs) under-performs one-stage SFT (over the concatenation of D0 and D1 for two epochs) (B vs. C in table above)
- incorporating D2 using either strategy achieves similar performance (E vs. F in table above)
b) Denoised SFT Data
- used the code-based critic model to construct SFT data
- Experimental results show that we can achieve similar average accuracy using either D2,SFT,H or D2,SFT (D vs. E in Table above)
- D2,SFT,H is only 30.6% of the latter’s size, indicating the usefulness of the filtering
c) DPO or SFT
- Based on a reasonably good model M1 (trained with D0 and D1, such as C in Table above), we can either self-improve it via SFT or DPO
- compare using the positive (question, code) pairs in the DPO data for another round of SFT, which results in a 1.8% drop in accuracy on downstream tasks (G vs. I in Table above)
d) DPO with SFT
- experiments indicate that DPO training is relatively insensitive to the weight of the SFT loss
- tested with λ = 1.0 and λ = 2.0, both of which resulted in similarly good performance (77.8%).
- Table below shows the impact of the weight of the SFT loss in DPO training on the average accuracy and average response length in words on GSM8K and CMATH (L0: response length of the reference policy).
- In the table above, removing the SFT loss (i.e., λ = 0) from DPO training leads to a dramatic increase in response length, especially for Chinese tasks such as CMATH, and yields worse results than the reference policy model
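The combined objective this subsection refers to likely takes the common DPO-plus-SFT form sketched below, where c+ and c− are the chosen and rejected code responses for question q and λ weights the SFT (negative log-likelihood) term on c+ (the paper's exact weighting may differ):

```latex
\mathcal{L}(\theta) =
-\log \sigma\!\Big(
\beta \log \frac{\pi_{\theta}(c^{+} \mid q)}{\pi_{\mathrm{ref}}(c^{+} \mid q)}
- \beta \log \frac{\pi_{\theta}(c^{-} \mid q)}{\pi_{\mathrm{ref}}(c^{-} \mid q)}
\Big)
\;-\; \lambda \, \log \pi_{\theta}(c^{+} \mid q)
```

Setting λ = 0 recovers plain DPO, which is consistent with the response-length blow-up observed above when the SFT regularizer is removed.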
e) Other Diverse-Format Resources
- also experiment with constructing similar-scale preference data using the diverse-format D2 based on WebInstruct
- resulting improvement in average accuracy is less substantial compared to that achieved with the Chinese diverse-format D2 (+0.9% vs. +1.8% on Llama3instruct; +0.6% vs. +2.5% on QWEN2Mathinstruct)
v) Using Critic model as Evaluator
- All scores are computed by comparing predictions with ground truth answers, using heuristics-based exact match (EM) following previous studies for fair comparison
- To explore the potential of using the critic model as a complementary evaluator, the correlation between the two evaluation methods is examined on the previously used benchmarks, using the original ground-truth answers (final-step answers if the answers are CoT-style)
- As shown in the table below, there is a very strong correlation (0.79) between the scores computed by the two evaluators (compared to the very-strong cutoff value of 0.71 and the strong cutoff value of 0.49)
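A small sketch of how such an agreement score can be computed, assuming the correlated quantities are accuracy scores produced by the two evaluators over the same set of benchmarks (an assumption about granularity):

```python
from scipy.stats import pearsonr

def evaluator_agreement(em_scores: list[float], critic_scores: list[float]) -> float:
    """Pearson correlation between accuracies from the heuristic exact-match (EM)
    evaluator and from the code-based critic evaluator."""
    r, _ = pearsonr(em_scores, critic_scores)
    return r

# Example usage with placeholder numbers (purely illustrative, not the paper's results):
# evaluator_agreement([74.2, 63.0, 88.9], [75.0, 65.2, 87.5])
```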
vi) Performance of Self-improved LLMs on More Out-of-Domain Tasks
- Table below shows OOD accuracy on MathBench (⋆: scored by the critic model; †: based on the numbers)
- Results above show that self-improved models demonstrate substantial gains on both subsets, with an accuracy improvement of 4.4%.
- On both subsets, the self-improved model consistently outperforms the initial one across all educational levels and subjects with notable improvements particularly in middle school tasks and English theoretical tasks
Conclusion
- introduce a novel paradigm for improving LLMs, which employs a code-based critic model to guide stages such as the creation and filtering of question-code data as well as complementary evaluation
- investigate various alignment algorithms using self-generated instruction/preference data for further improvement
- Results show the effectiveness of self-improving LLMs with this proposed paradigm
References:
- [1] Yu et al. SIaM: Self-Improving Code-Assisted Mathematical Reasoning of Large Language Models. arXiv:2408.15565.
- [2] Rafailov et al. Direct Preference Optimization: Your Language Model is Secretly a Reward Model. Advances in Neural Information Processing Systems.
- [3] Hong et al. ORPO: Monolithic Preference Optimization without Reference Model. arXiv:2403.07691.
- [4] Zheng et al. LLaMA-Factory: Unified Efficient Fine-Tuning of 100+ Language Models. arXiv:2403.13372.