A hands-on guide to using TRLx for Text Summarisation with Reinforcement Learning from Human Feedback

Ben Burtenshaw
4 min read · Mar 24, 2023


Graphic on RLHF from Stiennon et al. 2020 https://arxiv.org/abs/2009.01325

This post is a practical guide to implementing a text summarisation tool using Reinforcement Learning from Human Feedback (RLHF). Researchers at OpenAI applied RLHF to GPT models in their paper ‘Learning to summarize from human feedback’ (Stiennon et al. 2020). This post explores implementing RLHF with trlX, a recent package from CarperAI based on Hugging Face’s Transformer Reinforcement Learning (TRL) library.

The recent introduction of OpenAI’s ChatGPT has generated significant interest in RLHF within the language modelling community and beyond 🌐. OpenAI’s ‘Learning to summarize from human feedback’ paper showed that supervised fine-tuning alone on summarization data is suboptimal, and that optimizing the model against human preferences with reinforcement learning produces better summaries. This post follows the approach of that paper using the trlX library.

trlX is a distributed training framework inspired by the Transformer Reinforcement Learning (TRL) library. Focused on RLHF at scale, it is an excellent tool for reproducing recent findings from the RLHF literature. It currently supports the Proximal Policy Optimization (PPO) and Implicit Language Q-Learning (ILQL) algorithms, letting researchers concentrate on high-level reinforcement learning dynamics rather than boilerplate code for distributed training.
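To get a feel for the interface before building the summarization pipeline, here is a minimal example in the spirit of the usage snippet in the trlX README: a single trlx.train call that takes a base model and a reward function. The keyword-counting reward below is a toy placeholder of mine, not part of the summarization task.

import trlx

# Train a base model with PPO against an arbitrary reward function.
# The keyword-counting reward is only a toy placeholder to show the interface.
trainer = trlx.train(
    "gpt2",
    reward_fn=lambda samples, **kwargs: [sample.count("TL;DR") for sample in samples],
)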

1. Fine-tuning a pre-trained transformer model on our summarization dataset:

First, we will fine-tune a pre-trained transformer model for text summarization using the trlX library and Hugging Face Transformers. We will use the T5-small model, which is a lightweight version of the T5 transformer model designed for generative tasks like summarization. The trlX library will help us easily fine-tune the model using our custom dataset, abstracting away the training loop, optimization, and other details.

import trlx
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the pre-trained T5-small model and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Load the training and validation splits
# (load_summarization_datasets is a user-defined helper for your own data)
train_dataset, val_dataset = load_summarization_datasets()

# Fine-tune the model
trainer = trlx.Trainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    train_batch_size=8,
    gradient_accumulation_steps=2,
)

trainer.train()

The code begins by importing the necessary libraries and then loading the T5-small model and its tokenizer. Next, we load our custom summarization datasets for training and validation. Afterward, we create a Trainer instance from the trlX library, providing it with the model, tokenizer, training dataset, validation dataset, training batch size, and gradient accumulation steps. Finally, we call the trainer.train() method to initiate the fine-tuning process.
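For readers who want to run this supervised step directly against today’s Hugging Face APIs, the same fine-tuning can be sketched with Seq2SeqTrainer. This is a minimal sketch, assuming the datasets are Hugging Face Dataset objects with “document” and “summary” columns; the prompt prefix and hyperparameters are also assumptions for illustration.

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

def preprocess(batch):
    # "document" and "summary" are assumed column names for this sketch
    inputs = tokenizer(
        ["summarize: " + doc for doc in batch["document"]],
        max_length=512,
        truncation=True,
    )
    labels = tokenizer(text_target=batch["summary"], max_length=128, truncation=True)
    inputs["labels"] = labels["input_ids"]
    return inputs

tokenized_train = train_dataset.map(preprocess, batched=True)
tokenized_val = val_dataset.map(preprocess, batched=True)

training_args = Seq2SeqTrainingArguments(
    output_dir="t5-small-summarization",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    evaluation_strategy="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

trainer.train()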

2. Training a reward model:

Next, let’s train a reward model using the fine-tuned transformer model and a comparison dataset. The reward model estimates the quality of generated summaries during the reinforcement learning step. We will use the trlX library for this as well, which keeps the training loop simple.

To train the reward model we need labelled data, and, as with fine-tuning the pre-trained supervised model, we use the scientific TL;DR dataset from AllenAI.

from trlx.reward_model import RewardModel
from datasets import load_dataset

# Wrap the fine-tuned model and its tokenizer as a reward model
reward_model = RewardModel(model, tokenizer)

# Load the scientific TL;DR data ("Abstract" is one of the dataset's configurations)
comparison_dataset = load_dataset("allenai/scitldr", "Abstract")["train"]

# Train the reward model
reward_trainer = trlx.RewardModelTrainer(
    reward_model=reward_model,
    train_dataset=comparison_dataset,
    train_batch_size=8,
)

reward_trainer.train()

In the code above, we create a RewardModel from the fine-tuned transformer model and its tokenizer, and load the comparison dataset, which contains pairs of summaries with their corresponding quality scores. We then create a RewardModelTrainer instance, providing it with the reward model, the comparison dataset, and the training batch size, and call reward_trainer.train() to fit the reward model.
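For intuition, the heart of reward-model training in Stiennon et al. (2020) is a pairwise comparison loss: the reward model should score the summary humans preferred above the one they rejected. Below is a minimal PyTorch sketch of that loss, assuming a reward model that maps a batch of token ids to one scalar score per sequence; it is not trlX’s internal implementation.

import torch.nn.functional as F

def pairwise_reward_loss(reward_model, chosen_ids, rejected_ids):
    # Pairwise loss from Stiennon et al. (2020):
    # -log sigmoid(r(chosen) - r(rejected)), averaged over the batch.
    # reward_model is assumed to return one scalar score per sequence.
    chosen_rewards = reward_model(chosen_ids)      # shape: (batch,)
    rejected_rewards = reward_model(rejected_ids)  # shape: (batch,)
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()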

3. Fine-tuning the model using PPO:

Finally, we will apply Proximal Policy Optimization (PPO) to fine-tune the transformer model using the trained reward model as guidance for reinforcement learning. PPO is a popular and efficient reinforcement learning algorithm that helps optimize policies in complex environments. The trlX library makes it easy to use PPO for fine-tuning text summarization models with the help of the reward model.
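Before looking at the trainer call, it helps to recall what PPO actually optimizes: a clipped surrogate objective that keeps each update from pulling the policy too far away from the one that generated the rollouts. The function below is a generic PyTorch sketch of that objective, not trlX’s internal implementation.

import torch

def ppo_clipped_loss(logprobs, old_logprobs, advantages, clip_range=0.2):
    # Clipped surrogate objective from Schulman et al. (2017).
    # logprobs / old_logprobs: token log-probabilities under the current and
    # rollout policies; advantages: estimated advantages for those tokens.
    ratio = torch.exp(logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    # Negate because we minimize a loss but want to maximize the objective
    return -torch.min(unclipped, clipped).mean()

With that in mind, the trainer call for this step looks like this: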

# Fine-tune the summarization model with PPO, guided by the reward model
ppo_trainer = trlx.PPOTrainer(
    model=model,
    tokenizer=tokenizer,
    reward_model=reward_model,
    train_dataset=train_dataset,
    train_batch_size=8,
)

ppo_trainer.train()

In the code above, we create an instance of the PPOTrainer class, providing it with the model assets and parameters. The PPOTrainer will use the reward model to guide the fine-tuning process of the transformer model, optimizing it to generate better summaries based on human preferences. Finally, we call the ppo_trainer.train() method to start the PPO-based fine-tuning process.
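As a point of reference, the released trlX library usually drives this PPO stage through a single trlx.train call with a reward_fn callback and a YAML config, rather than a trainer class. The sketch below follows that documented pattern; score_summaries, the prompt lists, and the config path are assumptions for this example.

import trlx
from trlx.data.configs import TRLConfig

def reward_fn(samples, **kwargs):
    # samples are prompt + generated-summary strings; score_summaries is a
    # hypothetical helper that scores them with the trained reward model
    # and returns one float per sample.
    return score_summaries(reward_model, tokenizer, samples)

# The config path is an assumption; trlX ships example PPO configs in its repo
config = TRLConfig.load_yaml("configs/ppo_config_summarization.yml")

trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=train_prompts,      # source documents to summarize
    eval_prompts=val_prompts,   # held-out documents for evaluation
    config=config,
)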

Conclusion

In this post, we used trlX to implement RLHF for a summarization task, following three steps: fine-tuning a pre-trained transformer model on our summarization dataset, training a reward model, and using the reward model to fine-tune the model via PPO. In place of the Reddit TL;DR dataset used in OpenAI’s Learning to Summarize with RLHF paper, we used the scientific TL;DR dataset from AllenAI.

References

GitHub — CarperAI/trlx: A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

Implementing RLHF: Learning to Summarize with trlX

Welcome to trlX’s documentation! — trlX documentation

Illustrating Reinforcement Learning from Human Feedback (RLHF)

RLHF, ‘online’ ML systems, and RL going mainstream

Understanding Reinforcement Learning from Human Feedback (RLHF): Part 1

RLHF — LessWrong

Aligning language models to follow instructions — OpenAI
