Trying to scale test-time compute with LLMs

Oct 17, 2024

I recently came across a research paper from DeepMind arguing that scaling test-time compute can be more effective than scaling LLMs with more parameters and data. We are hitting a wall in terms of data, and experimenting with bigger architectures is very expensive.

OpenAI’s recent o1 model seems to be doing the same thing. According to their blog post, they allow the model to “think” and “reason” for longer, which in practice means generating more tokens. They used reinforcement learning to train the model to generate and refine its chain of thought. Their post also shows that these models have gotten much better at coding, math, and reasoning problems, but they still perform similarly to GPT-4o on tasks like drafting and editing text.

I decided to try building something similar. Due to limited resources, I did not have the chance to fully explore the techniques mentioned in the paper, and I went off on a slight tangent, which I talk about in the last part of this post.

What is chain of thought and why does it help the model perform better?

Simply put, the model breaks the query down into smaller, simpler problems, which “steers” it toward a logical answer. I use the word “steer” because an LLM is a next-token predictor; with better, more logical context, it is more likely to produce tokens consistent with everything that came before. This research paper demonstrates and quantifies the performance boost from chain-of-thought prompting. The key idea is that the thoughts act as a bridge between the input and the output.

Also, note that LLMs are few-shot learners. With examples of thought processes and solutions in the input prompt, you can achieve better and sometimes more grounded results. The evaluation metrics reported for models like Gemma, Llama, etc. use few-shot chain-of-thought as part of the prompt.
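
To make this concrete, here is a minimal sketch of what a few-shot chain-of-thought prompt can look like. The worked example and the `build_cot_prompt` helper are purely illustrative, not taken from any particular benchmark.

```python
# A minimal sketch of a few-shot chain-of-thought prompt.
# The worked example below is illustrative, not from an actual benchmark.
FEW_SHOT_COT = """Q: A farmer has 17 sheep and buys 5 more. How many sheep does he have now?
A: Let's think step by step. He starts with 17 sheep. Buying 5 more gives 17 + 5 = 22. The answer is 22.

Q: {question}
A: Let's think step by step."""

def build_cot_prompt(question: str) -> str:
    """Append the user's question after the worked example(s)."""
    return FEW_SHOT_COT.format(question=question)

print(build_cot_prompt("A train travels 60 km/h for 2.5 hours. How far does it go?"))
```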

Ideally, in my opinion, pass@1 with zero-shot prompting is what we should be aiming for. It is not a good user experience to say “run the prompt 5 times and you should get the correct answer once.” That says more about the limitations of the LLM and of transformer-based models.

Source: https://arxiv.org/pdf/2201.11903

Tree-of-thoughts:

LLMs are generative and non-deterministic by nature. How about we use this property to exploit the model’s creativity? Spin off multiple chains of thought, explore each branch, and let the model pick the best possible answer. This is essentially the idea behind tree-of-thoughts.

Source: https://arxiv.org/pdf/2305.10601
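
Here is a rough breadth-first sketch of that idea. `llm_generate` and `llm_score` are stand-in placeholders for real model calls, and the branching, depth, and keep values are arbitrary.

```python
import random
from typing import List, Tuple

def llm_generate(prompt: str, n: int = 1) -> List[str]:
    """Placeholder for a real sampling call to your model of choice."""
    return [f"thought-{random.randint(0, 999)}" for _ in range(n)]

def llm_score(question: str, chain: str) -> float:
    """Placeholder: ask the model to rate how promising a partial chain is (0-1)."""
    return random.random()

def tree_of_thoughts(question: str, branching: int = 3, depth: int = 2, keep: int = 2) -> str:
    """Breadth-first tree of thoughts: expand several candidate thoughts per node,
    score them, and keep only the most promising branches at each level."""
    frontier = [""]  # partial chains of thought
    for _ in range(depth):
        candidates: List[Tuple[float, str]] = []
        for partial in frontier:
            prompt = f"{question}\n{partial}\nNext step:"
            for thought in llm_generate(prompt, n=branching):
                chain = f"{partial}\n{thought}".strip()
                candidates.append((llm_score(question, chain), chain))
        candidates.sort(key=lambda c: c[0], reverse=True)
        frontier = [chain for _, chain in candidates[:keep]]
    return frontier[0]  # the highest-scoring chain found

print(tree_of_thoughts("Use 4, 9, 10, 13 and +, -, *, / to make 24."))
```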

Reflexion:

Ask the model to critique the input and it will generate some feedback. This feedback is then sent back to the model (along with the input), which is asked to improve based on it. The authors argue that this self-reflective feedback acts as a ‘semantic’ gradient signal, giving the agent a concrete direction to improve in and helping it learn from prior mistakes to perform better on the task.
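
In code, the loop might look something like the sketch below. `llm_generate` is again a placeholder for a single model call, and the prompts are simplified; this is not the exact prompting used in the Reflexion paper.

```python
def llm_generate(prompt: str) -> str:
    """Placeholder for a single model call; swap in your actual client."""
    return "stub response"

def reflexion_loop(question: str, draft: str, rounds: int = 2) -> str:
    """Critique the current answer, then feed the critique back in with the
    question and ask for a revised answer."""
    answer = draft
    for _ in range(rounds):
        critique = llm_generate(
            f"Question: {question}\nAnswer: {answer}\n"
            "Point out any mistakes or missing steps in this answer."
        )
        answer = llm_generate(
            f"Question: {question}\nPrevious answer: {answer}\n"
            f"Feedback: {critique}\n"
            "Rewrite the answer, fixing the issues raised above."
        )
    return answer
```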

Beam Search:

This algorithm helps in the sampling part of the generation process. Given the probability distribution over the next token, there are several methods you can use to sample it: greedy, beam search, top-p (nucleus) sampling, etc. Greedy sampling, as the name suggests, selects the token with the highest probability. A major drawback is that the output tends to lack creativity and be predictable. Beam search takes another approach: it keeps the top few candidate sequences (the number is set by a parameter known as the “beam width”), and only those are expanded further.

Let’s consider a beam width of 3 for this example. At each step the model predicts probabilities for all possible tokens; with a beam width of 3, only the top 3 sequences are kept, ranked by cumulative probability, which is the product of all token probabilities so far (or, equivalently, the sum of their log-probs).
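
To make the bookkeeping concrete, here is a toy sketch with made-up probabilities. A real decoder would get these distributions from the model’s softmax at every step; the numbers here are only for illustration.

```python
import math
from typing import Dict, List, Tuple

# Toy next-token distributions keyed by the text so far (illustrative numbers only).
TOY_MODEL: Dict[str, Dict[str, float]] = {
    "": {"the": 0.5, "a": 0.3, "an": 0.2},
    "the": {"cat": 0.6, "dog": 0.4},
    "a": {"cat": 0.5, "dog": 0.5},
    "an": {"owl": 1.0},
}

def beam_search(beam_width: int = 3, steps: int = 2) -> List[Tuple[float, str]]:
    """Keep the `beam_width` sequences with the highest cumulative log-prob at each step."""
    beams: List[Tuple[float, str]] = [(0.0, "")]  # (sum of log-probs, text so far)
    for _ in range(steps):
        candidates = []
        for logp, text in beams:
            for token, p in TOY_MODEL.get(text, {}).items():
                new_text = (text + " " + token).strip()
                candidates.append((logp + math.log(p), new_text))
        # keep only the top `beam_width` partial sequences
        beams = sorted(candidates, reverse=True)[:beam_width]
    return beams

for logp, text in beam_search():
    print(f"{text!r}: cumulative log-prob {logp:.3f}")
```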

Implementation:

With this in mind, we can now go ahead and implement our bootleg o1. Here is a rough outline of how it works (a code sketch follows the note below):
1. Generate steps on how to solve the problem (only the steps, not the solution); let’s call this the ideation phase.
2. With the generated ideas, ask the model to generate a solution based on them.
3. Ask the same model to critique the generated solution and the steps taken, and rate it from 1–10 (1 being very flawed and 10 being a perfect solution).
4. Using the feedback from step 3, ask the model to reflect on and iterate over the generated solution.
5. Based on the scores from step 3, carry the top beam_width children on to the next stage of refinement and critiquing.
6. Repeat steps 3 to 5 until the stop condition is met.

**(To my understanding, the original paper appears to append all reasoning and thinking tokens to the input from the beginning, meaning that at stage ’n’ it incorporates all tokens from stages 1 through n-1; in my implementation, I only carry forward the latest one.)**
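
Below is a condensed sketch of the outline above. It is not the actual code from my repo: `llm_generate` is a placeholder for whatever model you call, the prompts are heavily simplified, and the score parsing is just one way to do it.

```python
import re
from typing import List, Tuple

def llm_generate(prompt: str) -> str:
    """Placeholder for a single call to the underlying model."""
    return "stub"

def solve(question: str, beam_width: int = 3, n_ideas: int = 6, max_depth: int = 3) -> str:
    """Ideate -> solve -> critique/score -> refine, keeping the top `beam_width`
    candidates at every round (steps 1-6 from the outline above)."""
    # 1. Ideation: plans only, no solutions yet
    ideas = [llm_generate(f"List the steps to solve, without solving:\n{question}")
             for _ in range(n_ideas)]
    # 2. Turn each plan into a candidate solution
    candidates = [llm_generate(f"{question}\nFollow these steps:\n{idea}\nSolution:")
                  for idea in ideas]
    for _ in range(max_depth):
        scored: List[Tuple[float, str]] = []
        for sol in candidates:
            # 3. Critique the candidate and extract a 1-10 score
            critique = llm_generate(
                f"{question}\nProposed solution:\n{sol}\n"
                "Critique this solution and end with 'Score: N/10'.")
            match = re.search(r"Score:\s*(\d+)", critique)
            score = float(match.group(1)) if match else 1.0
            # 4. Refine the solution using only the latest feedback
            refined = llm_generate(
                f"{question}\nSolution:\n{sol}\nFeedback:\n{critique}\n"
                "Rewrite the solution addressing the feedback.")
            scored.append((score, refined))
        # 5. Keep the best `beam_width` children for the next round
        scored.sort(reverse=True)
        candidates = [sol for _, sol in scored[:beam_width]]
    return candidates[0]
```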

You can take this a step further by fine-tuning your small model to produce valid chains of thought, feedback, and refinements. You can do this by having a bigger model, say GPT-4o or Claude-3.5-Sonnet, produce the intermediate thought/critique/refinement traces and using them to tune the smaller model. I wanted to do this via DPO (direct preference optimization) but unfortunately couldn’t due to limited compute. Instead, I fine-tuned the small language model (with LoRA) on the outputs of the bigger model.
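
As a rough sketch, wrapping the small model with LoRA adapters looks like this (using the `peft` library). The model name, target modules, and hyperparameters below are illustrative defaults, not necessarily what I used, and the sketch assumes you have already collected the teacher-generated traces into a training set.

```python
# A minimal sketch: wrap a small student model with LoRA adapters before
# fine-tuning it on traces generated by a bigger teacher model.
# Hyperparameters and target modules here are illustrative, not the exact ones I used.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base = "google/gemma-1.1-2b-it"            # small student model
tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base)

lora_cfg = LoraConfig(
    r=16,                                   # low-rank adapter dimension
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],    # attention projections to adapt
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()          # only the adapter weights are trainable
# ...then train with your usual Trainer / SFT loop on the teacher-generated traces.
```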

I have not had the chance to evaluate on the MATH dataset fully, but based on a vibe check, there are performance gains.

You could stop the refinement-critique process in a few ways, such as a regex check (if applicable for your use case) or asking the LLM once more whether it thinks it has reached the final answer. In my implementation, I let it run for max_depth rounds (3, in my case) regardless of whether it found the answer, under the impression that additional rounds of feedback and refinement could yield a better answer.
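
For illustration, a stop condition could look something like the check below. The `\boxed{}` convention and the score threshold are my assumptions for this sketch, not what the repo actually checks (in practice I simply ran to max_depth).

```python
import re

def should_stop(solution: str, critique_score: float, depth: int, max_depth: int = 3) -> bool:
    """Illustrative stop check: a boxed final answer (regex), a high enough
    self-assigned critique score, or hitting the depth limit."""
    has_final_answer = re.search(r"\\boxed\{.+?\}", solution) is not None  # MATH-style answers
    return has_final_answer or critique_score >= 9 or depth >= max_depth
```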

Anything beyond a depth of 3 resulted in a loop: there are no real gains after that, as the model keeps suggesting the same improvements and refines very little. Also, having tried it on a not-so-good model (gemma-1.1-2b-it), the critiques and refinements mostly looked like junk (hallucinated). It felt like the blind leading the blind: the generated idea was sometimes wrong, and the feedback for improving it was also wrong. Observing these intermediate results makes the entire process look like a hallucination in a loop (a hallucination loop, if you will).

After trying it with better models such as gemma-2-2b-it, GPT-4o-mini, and Claude-3-Haiku, it developed much better strategies and refinements, even for harder questions.

All in all, test-time compute might be the strategy for improving answer accuracy. I see this going one of two ways:

  1. If your application/use case is fine with higher latency and wants to prioritize accuracy, you could use this approach at a much lower cost.
  2. If your application/use case cannot tolerate high latency, then this is probably not the best approach.

Final thoughts and remarks:

The base gemma-1.1-2b-it model could now pass level 1 and (some) level 2 MATH questions, and I am talking about pass@1 with zero-shot CoT; the metrics you see in a model card usually come from few-shot CoT prompts. For comparison, GPT-4o can solve level 5 questions correctly.

Although I went off on a slight tangent from the original paper, this seems to work well too. The changes from the original paper are:

  1. The original paper uses all the reasoning tokens from the intermediate stage(s).
  2. The original paper also uses another variation of beam search called lookahead search, which explores a sub-tree and then, based on the score, rolls back to the previous parent node. This is more time-consuming.

The method of scaling test-time compute itself isn’t new. Researchers have already explored other tree-search algorithms like MCTS (Monte Carlo Tree Search), Q*, etc.

Up until now, scaling laws primarily consisted of throwing in more data and more compute. Now we might see test-time compute become a part of this as well.

The implementation can be found here on my GitHub. Usage and requirements are mentioned in the repo.
