Data-Centric AI in Action: Boosting Model Performance with Synthetic Data from Large Language Models
In my previous article, I discussed how to perform error analysis on machine learning models and decide what kind of data is needed to improve them. In the case study discussed there, we assumed that unlabeled data was available to choose from a data dump. But what if you do not have access to such a data dump to augment your data? In the following, you will see that in such a scenario, assuming you are working with text data, it is possible to use Large Language Models (LLMs) to generate synthetic text data that makes the model perform as well as if you had additional real-world data.
Organization of the article
Here is how the article is organized:
1. Case study description
2. Solution approach
3. Hands-on solution
All the code used in this article can be found in this GitHub repo. Hope you enjoy the discussion!
The case study problem
The case study will be similar to my previous article, with a little twist towards the end. Suppose you have a limited set of annotated data with which you have built an ML model. You have done your hyperparameter tuning, and the model has reached its limit, i.e., you are unable to improve your model metrics any further just by changing the algorithm or its parameters. You now plan to add more data to the model (data augmentation), but you neither have any extra data at your disposal nor the time or money to outsource the annotation task. How should you go about generating new data to improve your model?
The solution approach
The solution to this type of problem generally follows the steps depicted in Fig. 1. First, you understand which kinds of data led your model to perform badly. In classification problems, looking at the precision, recall, and confusion matrix helps, as we will see in the example discussed below. Once you identify the data points on which your model did not perform well, you can use them as few-shot examples to prompt a large language model to generate similar text data.
This approach ensures that, for annotation, you do not use the type of data on which your model already performs well, which saves annotation cost and time. Next, we will go through a hands-on demonstration of the above problem using a Google Colab notebook. You can access the code in this GitHub repo.
The solution
The initial data preparation and baseline model building steps are the same as in the previous article. For completeness, I have kept them separately in this notebook.
Setting up the scenario: We will use the well-known ham-spam dataset from Hugging Face to set up the use case. First, we will split the dataset and keep 40% of it as the data dump or pool; the remaining 60% will be split into a 60/40 train-test split. Some data is kept aside to make sure the model performs badly, helping us build the use case. The notebook can be found here.
# imports needed for this step
import pandas as pd
from datasets import load_dataset
# Load the dataset
dataset = load_dataset("sms_spam")
df = pd.DataFrame(dataset['train'])
# check class distribution
df['label'].value_counts(normalize = True).round(2) # we already have an imbalanced dataset: ~13% spam and ~87% ham
def sample_df(df, train_fraction, random_state = 123):
    sample_size = int(len(df)*train_fraction)
    # use the passed-in random_state rather than a hardcoded seed
    train_sample = df.sample(n=sample_size, random_state=random_state)
    test_sample = df.drop(train_sample.index)
    return train_sample, test_sample
traintest_sample, pool_sample = sample_df(df = df, train_fraction = 0.6, random_state = 123)
train_sample, test_sample = sample_df(df = traintest_sample, train_fraction = 0.6, random_state = 123)
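As a quick sanity check, we can print the split sizes: the SMS Spam dataset has 5,574 messages, so with these fractions we expect roughly 2,006 train, 1,338 test, and 2,230 pool rows (note that 1,338 matches the test support in the classification reports below):
# sanity check: verify the split sizes
print(len(train_sample), len(test_sample), len(pool_sample))  # expected roughly: 2006 1338 2230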
Create embeddings and train the baseline model: Next, we will use the train data to build TF-IDF embeddings (fit on train data and transform on test data). See below:
# imports for vectorization and the baseline model
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score, classification_report
# Initialize the TF-IDF Vectorizer. To reduce the number of features and the sparsity, I have used min_df = 0.01.
# This means we will only keep tokens that appear in at least 1% of the training documents.
vectorizer = TfidfVectorizer(min_df = 0.01)
# Fit and transform the text data to create TF-IDF vectors
train_tfidf_mat = vectorizer.fit_transform(train_sample['sms'])
test_tfidf_mat = vectorizer.transform(test_sample['sms'])
# baseline model
# Initialize the Multinomial Naive Bayes classifier
nb_classifier = MultinomialNB()
# Train the classifier on the training data
nb_classifier.fit(train_tfidf_mat, train_sample['label'])
# Make predictions on the testing data
y_pred = nb_classifier.predict(test_tfidf_mat)
# Calculate the accuracy of the classifier
accuracy = accuracy_score(test_sample['label'], y_pred)
print(f"Accuracy: {accuracy.round(2)}")
# Print the classification report
report = classification_report(test_sample['label'], y_pred)
print("Classification Report:")
print(report)
And we get the following baseline:
Accuracy: 0.97
Classification Report:
precision recall f1-score support
0 0.97 1.00 0.99 1156
1 0.99 0.82 0.89 182
accuracy 0.97 1338
macro avg 0.98 0.91 0.94 1338
weighted avg 0.97 0.97 0.97 1338
Find the issues and the data that can address them: As expected, the model did not perform very well at detecting spam; in particular, the recall for the spam class is low (0.82). So, any data augmentation should focus on improving recall. Recall is given as TP/(TP+FN), where TP = true positives and FN = false negatives, so increasing recall means decreasing false negatives. What we can therefore do is find the false negatives (FNs) in the test predictions:
# We will check false negatives, i.e., where the y_true = 1 but y_pred = 0.
FN_cases = test_sample[(test_sample['label'] == 1) & (y_pred == 0)]
# let's save these FN cases, so that these can be used later
FN_cases.to_csv("FN_cases.csv", index = False)
print(len(FN_cases))
FN_cases.head()
33
sms label
0 22 days to kick off! For Euro2004 U will be ke... 1
1 Twinks, bears, scallies, skins and jocks are c... 1
2 Will u meet ur dream partner soon? Is ur caree... 1
3 Hello darling how are you today? I would love ... 1
4 Check Out Choose Your Babe Videos @ sms.shsex.... 1
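The arithmetic checks out: the spam class has a support of 182 with a recall of 0.82, and 182 − 33 = 149 true positives give 149/182 ≈ 0.82. A quick way to cross-check these counts is sklearn's confusion matrix:
# cross-check the FN count: ravel() returns tn, fp, fn, tp for binary labels
from sklearn.metrics import confusion_matrix
tn, fp, fn, tp = confusion_matrix(test_sample['label'], y_pred).ravel()
print(f"FN = {fn}, TP = {tp}, recall = {tp / (tp + fn):.2f}")  # FN = 33, recall ≈ 0.82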
Creating the prompt: Now, we will use these 33 FN observations and pass them as examples to the LLM to generate similar data. For that, we need to create the prompt, and we will write our own function to generate it. The code can be found here. The function should be able to do the following:
- accept the FN cases as input to the function
- sample n_examples observations from the FN cases in a bootstrapped fashion (sampling with replacement)
- prompt the LLM using the n_examples to generate n_out number of similar examples
Let’s see how it is done:
import random
import pandas as pd
# read back the FN cases saved in the previous step
FN_cases_df = pd.read_csv("FN_cases.csv")
def generate_prompt(FN_cases_df, n_examples, n_out):
    prompt_header = f"""
Consider yourself as a synthetic data generator.
I will be giving you 'spam' messages from a ham-spam dataset.
These specific spam messages are where the model did not perform well.
You need to generate more such spam messages for the model to retrain.
Here are {n_examples} examples of the spam messages:
"""
    prompt_footer = f"""
Now, generate {n_out} more such spam messages.
Output each generated message in the following format:
<START> <generated_message> <END>
"""
    # bootstrap: sample example indices with replacement
    ind_list = random.choices(FN_cases_df.index, k = n_examples)
    prompt_body = """\n"""
    for i, j in enumerate(ind_list):
        prompt_body = prompt_body + f"Example {i}:\n" + FN_cases_df['sms'].loc[j] + "\n"
    final_prompt = prompt_header + prompt_body + prompt_footer
    return final_prompt
test_prompt = generate_prompt(FN_cases_df = FN_cases_df, n_examples = 5, n_out = 10)
print(test_prompt)
print(test_prompt)
Consider yourself as a synthetic data generator.
I will be giving you 'spam' messages from a ham-spam dataset.
These specific spam messages are where the model did not perform well.
You need to generate more such spam messages for the model to retrain.
Here are 5 examples of the spam messages:
Example 0:
Hi its LUCY Hubby at meetins all day Fri & I will B alone at hotel U fancy cumin over? Pls leave msg 2day 09099726395 Lucy x Calls£1/minMobsmoreLKPOBOX177HP51FL
Example 1:
Dont forget you can place as many FREE Requests with 1stchoice.co.uk as you wish. For more Information call 08707808226.
Example 2:
Hi ya babe x u 4goten bout me?' scammers getting smart..Though this is a regular vodafone no, if you respond you get further prem rate msg/subscription. Other nos used also. Beware!
Example 3:
FREE2DAY sexy St George's Day pic of Jordan!Txt PIC to 89080 dont miss out, then every wk a saucy celeb!4 more pics c PocketBabe.co.uk 0870241182716 £3/wk
Example 4:
TheMob>Hit the link to get a premium Pink Panther game, the new no. 1 from Sugababes, a crazy Zebra animation or a badass Hoody wallpaper-all 4 FREE!
Now, generate 10 more such spam messages.
Output each generated message in the following format:
<START> <generated_message> <END>
Synthetic data generation using the LLM: The above prompt can now be passed to an LLM to generate synthetic FN observations. We will use the open-source Phi-3-mini model from Microsoft, which is a relatively small (3.8 billion parameters) yet powerful model that can easily be loaded with 4-bit quantization on the T4 GPU (16 GB) freely available in Google Colab. We will also use a LangChain pipeline to pass the prompt to the LLM and run it. Here's how to do it:
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from langchain_core.prompts import PromptTemplate
# to load the model in 4 bits
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)
# Initialize the model
model_id = "microsoft/Phi-3-mini-128k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True
)
# create pipeline instance (do_sample=True so that the temperature setting actually takes effect)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, do_sample=True, temperature=0.5, max_new_tokens=4000)
hf = HuggingFacePipeline(pipeline=pipe)
# define the template inside which our generated prompt will be passed
template = """<|user|>\n{question}<|end|>\n<|assistant|>"""
prompt = PromptTemplate.from_template(template)
# chain the prompt template with the HF pipeline
chain = prompt | hf
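Before launching the full generation run, an optional smoke test (a minimal check of my own, assuming the model above loaded correctly) confirms that the chain returns text:
# optional smoke test: confirm the chain produces output before the long generation run
print(chain.invoke({"question": "Reply with a single short sentence."})[:200])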
Our pipeline is ready! Now we will go ahead with the generation: we will run it 5 times, each time asking the LLM to generate 10 similar observations. Make sure to save the results!
# now we run the LLM using our prompt and save the results
from tqdm import tqdm
file_path = "/content/llm_output.txt"
for c in tqdm(range(5)):
    # use a new name so we don't overwrite the PromptTemplate object called `prompt`
    gen_prompt = generate_prompt(FN_cases_df = FN_cases_df, n_examples = 5, n_out = 10)
    output = chain.invoke({"question": gen_prompt})
    # 'a' (append) mode creates the file on the first iteration, so no existence check is needed
    with open(file_path, 'a') as file:
        file.write(output + '\n ####')
Response extraction and cleaning: The raw response from the LLM cannot be passed directly to the model. We will first extract the observations (using the <START> and <END> tokens) from the response and then clean them to remove any additional tokens that are not part of the observations. The LLM also generated duplicate observations, which need to be removed as well. The following extraction and post-processing functions perform these operations (the code is available in this notebook):
def extract_responses(file_path, start_tag="<START>", end_tag="<END>"):
    """
    - read the raw LLM output from the text file
    - extract the lines situated between the <START> and <END> markers
    """
    extracted_lines = []
    with open(file_path, 'r') as file:
        capture = False
        for line in file:
            line = line.strip()
            if start_tag in line:
                capture = True
                extracted_lines.append(line.replace(start_tag, '').strip())
            elif end_tag in line:
                capture = False
                extracted_lines[-1] += ' ' + line.replace(end_tag, '').strip()
            elif capture:
                extracted_lines[-1] += ' ' + line.strip()
    return extracted_lines
def post_process(extracted_lines):
    """
    - remove entries that contain the '<|user|>' or '<|end|>' tags
    - remove special tags like '<END>', '####', '<|assistant|>'
    - deduplicate entries (there are a lot of duplicated entries due to hallucination)
    Note: these post-processing steps can vary based on the prompt and the model
    """
    processed_lines = []
    seen_lines = set()
    for line in extracted_lines:
        # Drop entries that still contain chat-template tags
        if '<|user|>' in line or '<|end|>' in line:
            continue
        # Remove special tags like '<|assistant|>', '<END>', '####'
        line = line.replace('<|assistant|>', '').replace('<END>', '').replace('####', '').strip()
        # Skip if the line is empty
        if not line:
            continue
        # Deduplicate entries
        if line not in seen_lines:
            seen_lines.add(line)
            processed_lines.append(line)
    return processed_lines
Note that these extraction and post-processing functions are not general; they need to be adapted to the prompt and the response that the LLM generates. Now we extract and clean the response and see how it looks:
file_path = "/content/llm_output_final.txt"
#extraction
extracted_response = extract_responses(file_path = file_path, start_tag="<START>", end_tag="<END>")
#clean llm output
cleaned_lines = post_process(extracted_lines = extracted_response)
# save the cleaned output
file_path = '/content/cleaned_llm_output.txt'
with open(file_path, 'w') as file:
for lines in cleaned_lines:
file.write(lines + '\n')
# to read the saved output
file_path = '/content/cleaned_llm_output.txt'
with open(file_path, 'r') as file:
lines = [line.strip() for line in file if line.strip()]
print("total examples", len(lines))
import pprint as pp
pp.pprint(lines)
A glimpse of the cleaned synthetic observations:
total examples 41
["Hey there! I've got a once-in-a-lifetime deal for you! Click on this link "
"and get your hands on a rare, limited edition collector's item. Hurry up, "
"this offer won't last forever!",
"Attention all O2 users! You've been selected for an exclusive offer. Click "
"on this link and claim your free upgrade to the premium plan. Don't miss out "
'on this amazing opportunity!',
"URGENT! You've been chosen to receive a special gift from your favorite "
'celebrity. Click on this link to claim your prize and get a sneak peek into '
'their personal life. This is a once-in-a-lifetime opportunity!',
"Hello, dear friend! I'm reaching out to you with a special offer. Click on "
'this link and get a free vacation package to your dream destination. Hurry '
"up, this offer won't last forever!", ...]
They actually look like spam messages! Note that we generated 50 observations, but only 41 are left after cleaning; the rest were duplicates, a byproduct of hallucination. Generally, smaller models are more prone to such hallucinations.
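Exact-match deduplication only removes verbatim repeats. As an optional extension (a sketch of mine, not part of the original pipeline), near-duplicate generations can also be filtered by thresholding TF-IDF cosine similarity:
# optional extension: drop near-duplicate generations via TF-IDF cosine similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
def drop_near_duplicates(lines, threshold = 0.9):
    """Keep a line only if its cosine similarity to every previously kept line is below threshold."""
    vec = TfidfVectorizer().fit(lines)
    kept, kept_vecs = [], []
    for line in lines:
        v = vec.transform([line])
        if all(cosine_similarity(v, kv)[0, 0] < threshold for kv in kept_vecs):
            kept.append(line)
            kept_vecs.append(v)
    return kept
deduped_lines = drop_near_duplicates(lines, threshold = 0.9)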
Retraining with synthetic data: We can now go ahead, add this data to the training dataset, and retrain the model. Note that all the generated observations are FN cases and hence get the label '1' (spam).
# reading in the train and test datasets created before
train_sample = pd.read_csv('/content/train.csv')
test_sample = pd.read_csv('/content/test.csv')
# creating dataframe to add to the train data (adding the spam label)
df_to_add = pd.DataFrame({'sms': lines, 'label': [1] * len(lines)})
# concat the df with the train data
aug_train_sample = pd.concat([train_sample, df_to_add], ignore_index=True)
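A quick check that the augmentation did what we expect: the train set should grow by the 41 synthetic rows, all labeled as spam:
# verify the augmentation: 41 new rows, all with label 1
print(len(train_sample), '->', len(aug_train_sample))
print(aug_train_sample['label'].value_counts())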
We perform the usual steps now, i.e., vectorize the augmented train set and the test set and retrain the model:
# vectorize the data
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score, classification_report
vectorizer = TfidfVectorizer(min_df = 0.01)
# Fit and transform the text data to create TF-IDF vectors
train_tfidf_mat = vectorizer.fit_transform(aug_train_sample['sms'])
test_tfidf_mat = vectorizer.transform(test_sample['sms'])
## retrain model with augmented data
# Initialize the Multinomial Naive Bayes classifier
nb_classifier = MultinomialNB()
# Train the classifier on the training data
nb_classifier.fit(train_tfidf_mat, aug_train_sample['label'])
# Make predictions on the testing data
y_pred = nb_classifier.predict(test_tfidf_mat)
# Calculate the accuracy of the classifier
accuracy = accuracy_score(test_sample['label'], y_pred)
print(f"Accuracy: {accuracy.round(2)}")
# Print the classification report
report = classification_report(test_sample['label'], y_pred)
print("Classification Report:")
print(report)
Accuracy: 0.98
Classification Report:
precision recall f1-score support
0 0.98 1.00 0.99 1156
1 0.97 0.87 0.92 182
accuracy 0.98 1338
macro avg 0.97 0.93 0.95 1338
weighted avg 0.98 0.98 0.98 1338
It worked! As you can see, with only 41 synthetic FN observations added to the training data, the recall for the spam class improved by 5 percentage points (0.82 to 0.87) and the F1-score by 3 points (0.89 to 0.92). If we compare this with our previous approach, where we collected 66 similar data points from the data dump (similar to the FN cases), we achieved the same F1-score of 0.92. So, at least for this use case, the LLM is able to generate observations that are as good as the original data.
Conclusion
To summarize, we looked at how to perform error analysis and how to use open-source LLMs to generate relevant synthetic data that works as well as real-world data.
I hope you enjoyed the discussion. If so, please like this article, follow me on LinkedIn, and star the corresponding GitHub repo for this article.