Reinforcement Learning for tuning language models (how to train ChatGPT)

ML Blogger
Dec 11, 2022


Large Language Models

The Large Language Model revolution started with the advent of the transformer architecture in 2017. Since then, the size of trained models has grown exponentially, and models with more than 100B parameters have been trained. These pre-trained models have changed the way NLP is done: it is much easier to take a pre-trained model and fine-tune it for a downstream task (sentiment analysis, question answering, entity recognition, etc.) than to train a model from scratch. Fine-tuning also needs far fewer labelled examples than training from scratch, which makes the whole NLP workflow much easier.

LLMs with more than about 10B parameters also display emergent abilities, i.e., model performance on various tasks jumps sharply once the parameter count crosses a certain threshold. These large models also show impressive few-shot learning capabilities and can be adapted to new tasks through prompt-based tuning.

Using LLMs

As discussed above, LLMs can be used in two main ways:

  • Prompting: In the prompting paradigm, a pre-trained LLM is given a snippet of text as input and is expected to produce a relevant completion of it. In prompt engineering, the description of the task is embedded explicitly in the input (e.g., phrased as a question) instead of being given implicitly.
  • Fine-tuning: Fine-tuning is a way of applying transfer learning. A model that has already been trained for one task is further trained (tuned) so that it performs a second, similar task.

The recent success of ChatGPT has shown how much fine-tuning can improve a model's performance. In this article we will look at the method used to train ChatGPT, called RLHF (Reinforcement Learning from Human Feedback).

Using RL with human feedback for fine-tuning

To fine-tune LLMs with RL, we need to frame the problem as an agent-environment setting, in which the agent (the policy) interacts with the environment and receives a reward for its actions. This reward is then used as the feedback signal to train the model.

The mapping of the entities is as follows:

  • Agent (policy): the LLM (Large Language Model)
  • Environment: the reward function (a reward model), which generates the rewards. It consumes both the input and the output of the LLM to produce a reward.

The reward is used in a loss function and the policy is updated.
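To make this mapping concrete, here is a minimal toy sketch of the agent-environment loop, under loudly stated assumptions: the "LLM" is replaced by a softmax policy over three fixed candidate completions, `reward_model` is a hypothetical hand-written scorer standing in for a learned reward model, and the update rule is plain REINFORCE rather than the PPO used for real RLHF. It only illustrates the act, score, update cycle.

```python
import math
import random

# Hypothetical toy "LLM": a softmax policy over a fixed set of completions
# for a single prompt. A real policy would be a transformer.
CANDIDATES = ["a great movie", "a bad movie", "an ok movie"]


def reward_model(prompt: str, completion: str) -> float:
    """Hypothetical stand-in for a learned reward model.

    Like the real one, it consumes both the input and the output."""
    return 1.0 if "great" in completion else 0.0


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train(prompt: str, steps: int = 500, lr: float = 0.1, seed: int = 0):
    rng = random.Random(seed)
    logits = [0.0] * len(CANDIDATES)  # the policy's parameters
    for _ in range(steps):
        probs = softmax(logits)
        # Agent acts: sample a completion from the current policy.
        action = rng.choices(range(len(CANDIDATES)), weights=probs)[0]
        # Environment responds: the reward model scores (prompt, completion).
        reward = reward_model(prompt, CANDIDATES[action])
        # REINFORCE: d log pi(action) / d logit_k = 1[k == action] - probs[k]
        for k in range(len(logits)):
            indicator = 1.0 if k == action else 0.0
            logits[k] += lr * reward * (indicator - probs[k])
    return softmax(logits)


if __name__ == "__main__":
    final_probs = train("Describe the film:")
    for text, p in zip(CANDIDATES, final_probs):
        print(f"{p:.3f}  {text}")
```

After training, most of the probability mass should sit on the completion the reward model prefers. A real LLM policy has a vast action space (every possible token sequence), which is why a stabilised optimiser such as PPO is used instead.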

Figure: overview of the approach

Policy

This is the pre-trained LLM which is being fine-tuned.

Reward Model

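Before the reward model is trained, data is collected from human labelers. For each input x, several responses yᵢ are generated by sampling from the LLM. The labelers are then asked to rank these yᵢ, giving the highest rank to the best response. Taking the index b of the best response as the label, the reward model r is trained over the dataset S of labelled samples to maximise the probability of the chosen response, using a loss of the type

$$\mathcal{L}(r) = -\,\mathbb{E}_{(x,\{y_i\}_i,\,b)\sim S}\left[\log\frac{e^{r(x,y_b)}}{\sum_i e^{r(x,y_i)}}\right] \qquad \text{(Eq 1)}$$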
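As a rough sketch of the loss in Eq 1, assume the reward model has already produced one scalar score per sampled response. The per-sample loss is then the negative log-softmax of the chosen response's score. The function names are hypothetical, and a real implementation would compute this on batched tensors so that gradients flow back into the reward model.

```python
import math
from typing import List, Sequence, Tuple


def best_of_n_loss(rewards: Sequence[float], best: int) -> float:
    """Negative log-probability of the human-chosen response b under a
    softmax over the reward model's scores r(x, y_i) (one sample of Eq 1)."""
    m = max(rewards)  # log-sum-exp with max subtraction for stability
    log_normalizer = m + math.log(sum(math.exp(r - m) for r in rewards))
    return log_normalizer - rewards[best]


def dataset_loss(samples: List[Tuple[Sequence[float], int]]) -> float:
    """Average of best_of_n_loss over a dataset S of (rewards, best) pairs."""
    return sum(best_of_n_loss(r, b) for r, b in samples) / len(samples)


# Four responses for one input; the labeler picked index 2 as the best.
print(best_of_n_loss([0.1, -0.4, 1.3, 0.2], best=2))
```

The loss shrinks as the reward model scores the chosen response higher than the alternatives; when all four scores are equal it is log 4.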

The reward model can also be trained with a ranking-type loss, where the model learns to reproduce the full ordering of the outputs rather than only maximising the probability of one of them.
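One common way to use the full ranking is to compare every pair of ranked responses and apply a logistic loss to their reward difference. This pairwise form is the one used in later RLHF work such as InstructGPT; it is an assumption here, not something the text above specifies. The sketch assumes the scores are already sorted by human rank, best first.

```python
import math
from itertools import combinations
from typing import Sequence


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def pairwise_ranking_loss(rewards_best_first: Sequence[float]) -> float:
    """Rewards are ordered by human rank: index 0 is the best response.

    For every pair (i, j) with i ranked above j, add -log(sigmoid(r_i - r_j)),
    which equals softplus(r_j - r_i), then average over all pairs."""
    pairs = list(combinations(range(len(rewards_best_first)), 2))
    if not pairs:
        raise ValueError("need at least two ranked responses")
    total = sum(
        softplus(rewards_best_first[j] - rewards_best_first[i]) for i, j in pairs
    )
    return total / len(pairs)


# Scores that agree with the human ranking give a small loss.
print(pairwise_ranking_loss([2.0, 1.0, 0.5, -1.0]))
```

Unlike the best-of-n loss, every ranked pair contributes a training signal, so a single labelled ranking of K responses yields K(K-1)/2 comparisons.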

Loss Function

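The objective maximised when training the policy is

$$\mathbb{E}_{x\sim D,\;y\sim\pi(\cdot\mid x)}\left[\,r(x,y) - \beta\log\frac{\pi(y\mid x)}{\rho(y\mid x)}\right] \qquad \text{(Eq 2)}$$

where r(x,y) is the reward model output and the second term is a KL-divergence penalty, weighted by a coefficient β, that keeps the policy π from deviating too far from the original pre-trained language model ρ during fine-tuning. (Averaged over y ~ π(·|x), the log-ratio equals KL(π‖ρ).)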
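The KL penalty is typically computed from the log-probabilities that the policy π and the frozen reference model ρ assign to the sampled response. Since log π(y|x) is a sum of per-token log-probabilities, the log-ratio is a sum of per-token differences. The sketch below assumes those per-token values are already available; the function name and the default weighting coefficient beta = 0.1 are hypothetical illustration values.

```python
from typing import Sequence


def kl_penalized_reward(
    reward: float,
    policy_token_logprobs: Sequence[float],
    reference_token_logprobs: Sequence[float],
    beta: float = 0.1,  # hypothetical KL coefficient
) -> float:
    """Return r(x, y) - beta * log(pi(y|x) / rho(y|x)) for one sampled y."""
    if len(policy_token_logprobs) != len(reference_token_logprobs):
        raise ValueError("both models must score the same tokens")
    # log pi(y|x) - log rho(y|x), accumulated token by token
    log_ratio = sum(
        lp - lr for lp, lr in zip(policy_token_logprobs, reference_token_logprobs)
    )
    return reward - beta * log_ratio


# The policy is more confident than the reference on both tokens,
# so the reward is reduced.
print(kl_penalized_reward(1.0, [-1.0, -2.0], [-1.5, -2.5], beta=0.1))
```

A larger beta pulls the policy harder towards ρ; a beta of zero lets the policy chase the reward model freely, which risks exploiting its weaknesses.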

The objective is optimised with the PPO (Proximal Policy Optimisation) algorithm. PPO is an on-policy algorithm, i.e., it optimises the policy directly, using samples generated by the current policy. Like TRPO, it limits how far each update can move the policy, which makes its updates more stable than those of vanilla policy-gradient methods.
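The heart of PPO is its clipped surrogate objective. The sketch below computes it from per-sample log-probabilities under the new and old policy together with precomputed advantages; advantage estimation (e.g. with a value network) is left out, and clip_eps = 0.2 is a commonly used default. It illustrates only the clipped objective, not the full PPO training loop.

```python
import math
from typing import Sequence


def ppo_clipped_loss(
    new_logprobs: Sequence[float],
    old_logprobs: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,
) -> float:
    """Negative PPO clipped surrogate objective, averaged over samples."""
    if not (len(new_logprobs) == len(old_logprobs) == len(advantages)):
        raise ValueError("inputs must have the same length")
    total = 0.0
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic bound: take the smaller of the unclipped and clipped terms,
        # so the update gains nothing from pushing the ratio outside the clip range.
        total += min(ratio * adv, clipped_ratio * adv)
    return -total / len(advantages)  # negate because optimisers minimise
```

The clipping is what gives PPO its stability: once the probability ratio leaves [1 - clip_eps, 1 + clip_eps] in the direction favoured by the advantage, the gradient through that sample vanishes.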

During this stage only the policy is updated; the reward model is kept fixed. The reward model is only a ‘proxy’ for human preferences, trained on a relatively small set of labelled examples, so optimising it jointly with the policy would cause it to overfit.

Overview of the training process

The training process is as follows:

  • Gather samples (x, y₀, y₁, y₂, y₃) via x ~ D, yᵢ ~ ρ(·|x). Humans pick the best yᵢ for each x.
Figure: step where samples are gathered from the language model for human feedback
  • The reward model is initialised to ρ and trained on the human-annotated samples using the loss in Eq 1.
  • Train π using PPO (Proximal Policy Optimisation) with the objective in Eq 2.
Figure: final step of fine-tuning the policy with the reward model

Open Source Libraries

TRL

TRL is a library for training language models with RL. With TRL you can train transformer language models with Proximal Policy Optimisation (PPO). The library is built on top of the 🤗 Hugging Face transformers library, so pre-trained language models can be loaded directly via transformers. At the time of writing, only decoder architectures such as GPT-2 are implemented.

  • Uses the PPO optimiser for training
  • Currently supports GPT-2 models for training

RL4LM

RL4LM is a library by the Allen Institute for AI for training language models with RL. Their approach and enhancements are presented in their paper. The highlights of the library are:

  • Supports multiple NLP tasks such as summarisation, sentiment and translation
  • Supports multiple metrics such as BLEU, ROUGE and BERTScore
  • Supports on-policy algorithms such as PPO, TRPO, A2C and NLPO (RL4LM’s novel approach)
  • Actor-critic policies supporting causal LMs

Other papers using similar approaches

CodeRL

CodeRL is a research model by Salesforce Research, described in the CodeRL paper. It is an enhancement of their earlier code generation model, CodeT5. The basic idea is that most code generation models are open loop, i.e., they are not trained to produce functionally correct or executable code, because the model receives no feedback on the correctness of its code. CodeRL addresses this by fine-tuning the pre-trained CodeT5 models with RL to improve the quality of the generated code.

The key ideas here are:

  • Use the code model as the policy network
  • Create a critic network that predicts the correctness of the generated program (this works even for incomplete programs)
  • Use unit tests to check the correctness of the final generated program
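A unit-test signal can be turned into a scalar reward along the lines of the sketch below, which grades outcomes from worst (does not compile) to best (passes every test), in the spirit of CodeRL. The specific reward values, the test-case format and the `solve` entry-point name are illustrative assumptions, not CodeRL's code. Note that `exec` runs arbitrary code; a real system would execute generated programs in a sandbox.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical test format: ((positional args...), expected return value)
TestCase = Tuple[Tuple[Any, ...], Any]


def unit_test_reward(program: str, tests: List[TestCase]) -> float:
    """Score a generated program by running it against unit tests.

    WARNING: exec() runs arbitrary code; real systems use a sandbox."""
    try:
        code = compile(program, "<generated>", "exec")
    except SyntaxError:
        return -1.0  # does not compile
    namespace: dict = {}
    try:
        exec(code, namespace)
        solve: Callable[..., Any] = namespace["solve"]  # hypothetical entry point
        passed = all(solve(*args) == expected for args, expected in tests)
    except Exception:
        return -0.6  # runtime error (including a missing `solve`)
    return 1.0 if passed else -0.3  # all tests pass vs. at least one fails
```

Grading the failure modes gives the policy a denser signal than a plain pass/fail reward: a program that runs but returns wrong answers scores better than one that crashes or does not parse.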

RLPrompt

RLPrompt takes a different approach to fine-tuning, described in the RLPrompt paper. Instead of fine-tuning the model, it fine-tunes the prompts given to the model. Searching for good prompts is usually a trial-and-error process, and the choice of prompt has a significant impact on model performance.

The key ideas of this approach are:

  • Use a policy network for prompting (this is also a language model, but a smaller one). A tunable layer is appended to the policy network to enable fine-tuning.
  • The prompt generated by the policy is fed into the generation language model (this can be a masked LM or a decoder-type LM).
  • The reward is constructed from the output of the generation language model.
