Exploring the power of Cohere.ai for text classification with a small dataset

Firi Berhane
5 min read · Aug 11, 2023


When it comes to text classification with small datasets, traditional AI methods often struggle to learn sufficient patterns and fail to generalize well. However, we can harness the power of large language models (LLMs) like those from OpenAI and cohere.ai. These models are pre-trained on vast amounts of text data, enabling them to capture complex language patterns and contextual information. In this article, we’ll explore the potential of cohere.ai’s generative LLM for text classification using a small dataset.

For our experiment, we’ll be working with the clickbait dataset, a collection of 32,000 headlines from sources like Wikinews, the New York Times, BuzzFeed, and more, each labeled as clickbait or non-clickbait. To keep things exciting, we’ll use only a fraction of the dataset: 5% (1,600 records) for training and 50% (15,174 records) for evaluation. The challenge worth tackling is whether we can accurately classify headlines using such a small training set. But why cohere.ai? Well, while you often hear about OpenAI models, I wanted to explore something less commonly demonstrated in the AI community.

Fine-tuning cohere.ai’s Generative Model:

Before anything, we’ll have to make the necessary installations and imports.

!pip install cohere

import pandas as pd
import cohere
import time
from cohere.custom_model_dataset import CsvDataset, InMemoryDataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In order to work with cohere.ai models in Python, you first have to sign up at cohere.ai and get your API key.

api_key = 'YOUR_COHEREAI_API_KEY_HERE'
co = cohere.Client(api_key)
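
If you want to confirm the key works before going further, a quick one-off generation call does the trick. This check is my addition rather than part of the original walkthrough, and it assumes the client’s default base model:

# Optional sanity check: a trivial generation call against the base model
response = co.generate(prompt="Say hello.", max_tokens=5)
print(response.generations[0].text)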

Preparing the Data for Fine-tuning:

To get started with the data processing, we’ll import and prepare the dataset by converting the “headlines” column into prompts and the “clickbait” column into completions. These prompt-completion pairs will serve as the foundation for fine-tuning.

clickbait_df = pd.read_csv("clickbait_data.csv")
clickbait_df = clickbait_df.rename(columns={'headline': 'Prompt', 'clickbait': 'Completion'})

In the dataset, clickbait is represented by 1 and non-clickbait by 0, so the numbers are replaced with “clickbait” and “not clickbait” correspondingly. Additionally, we’ll turn each headline into a clear instruction prompt by appending the suffix “\n Is the above statement a clickbait or not?\n\n###\n\n”. The “\n\n###\n\n” indicates to the model where the prompt ends and the completion begins when the model is put to use later on.

clickbait_df['Completion'].replace({1: ' clickbait\n', 0: ' not clickbait\n'}, inplace=True)
clickbait_df['Prompt'] = clickbait_df['Prompt'] + '\n Is the above statement a clickbait or not?\n\n###\n\n'
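
As a quick sanity check (an extra step, not part of the original pipeline), printing one prepared pair with repr() makes the escape characters and the “###” separator visible:

# Inspect one prepared prompt/completion pair to verify the formatting
print(repr(clickbait_df['Prompt'].iloc[0]))
print(repr(clickbait_df['Completion'].iloc[0]))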

Next, we’ll split the dataset into training and validation sets. We’ll allocate 5% for training and 50% for evaluation. The training set will be used to fine-tune the model, while the validation set will assess its performance.

# A function for sampling a given fraction of each group, optionally
# excluding rows already drawn into another split
def get_dec_percent(x, dec, exclude=None):
    if exclude is not None:
        # exclude comes out of groupby().apply() with a MultiIndex, so
        # compare against its original row-index level
        x = x[~x.index.isin(exclude.index.get_level_values(-1))]
    return x.sample(frac=dec)

df_train = clickbait_df.groupby('Completion').apply(get_dec_percent, 0.05)
df_validation = clickbait_df.groupby('Completion').apply(get_dec_percent, 0.50, exclude=df_train)
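
Before moving on, it’s worth confirming the split sizes and label balance; this verification step is my addition:

# Confirm the split sizes and that both labels are equally represented
print(len(df_train), len(df_validation))
print(df_train['Completion'].value_counts())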

To facilitate fine-tuning, we’ll convert the training and validation sets into an InMemoryDataset, a dataset class provided by cohere.custom_model_dataset. This lets us feed the data directly to cohere.ai without saving it to disk. We’ll also pass in the validation set to obtain evaluation metrics, giving us insight into the model’s performance on unseen prompts.

data_list = [tuple(row) for row in df_train.values]
eval_data_list = [tuple(row) for row in df_validation.values]
dataset_with_eval = InMemoryDataset(training_data=data_list, eval_data=eval_data_list)
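
Since CsvDataset is also imported above, a file-based alternative is possible if you’d rather persist the splits first. A minimal sketch, assuming a hypothetical file name and the CsvDataset signature shown in cohere.ai’s custom-model docs:

# Alternative: write the training split to disk and load it as a CsvDataset
df_train.to_csv("clickbait_train.csv", index=False, header=False)
dataset_from_csv = CsvDataset(train_file="clickbait_train.csv", delimiter=",")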

Fine-tuning the Model:

Using cohere.ai’s custom-model functionality, we’ll create a model named “prompt-completion-ft-with-eval” for fine-tuning. It builds on the prepared dataset and cohere.ai’s generative language model. The goal is to train and evaluate the model on the clickbait classification task, capturing the nuances and patterns necessary for accurate predictions.

finetune_with_eval_ = co.create_custom_model("prompt-completion-ft-with-eval", dataset=dataset_with_eval, model_type="GENERATIVE")

You can go to your cohere.ai account dashboard to check how the model performs once fine-tuning is done. The process is also relatively fast considering that it spends most of its time waiting in a queue; this one took about an hour.

It’s amazing how the model correctly classified all the validation completions, with an accuracy of 100%.

Evaluating the model’s performance:

We can see from the dashboard that the model got perfect scores, but we have to test how well it actually works outside the fine-tuning process. So we’ll use a separate test dataset, consisting of 1% of the clickbait dataset (320 prompt-completion pairs). We’ll load the test dataset and create lists to store the predicted and actual labels. The model will generate a completion for each prompt, and we’ll compare it with the expected completion from the test dataset. Metrics like accuracy, precision, recall, and F1-score then give us deeper insight into how the model is doing and whether it meets our evaluation standards.

df_test = clickbait_df.groupby('Completion').apply(get_dec_percent, 0.01, exclude=pd.concat([df_train, df_validation]))

predicted_labels = []
actual_labels = []

# Change the test dataframe into a list of dictionaries
test_list = df_test.to_dict('records')

# Generate text using the fine-tuned model
for i, item in enumerate(test_list, start=1):
    response = co.generate(
        model="YOUR MODEL ID HERE",
        prompt=item['Prompt'],
        stop_sequences=["\n\n###\n\n"],
        return_likelihoods='NONE')

    generated_completion = response.generations[0].text

    # The 'Completion' column holds the expected completion in the test dataset
    expected_completion = item['Completion']

    predicted_labels.append(generated_completion)
    actual_labels.append(expected_completion)

    # Print progress, then pause between calls (the original sleeps 25 s,
    # presumably to respect API rate limits)
    print(i)
    time.sleep(25)

# Calculate evaluation metrics
accuracy = accuracy_score(actual_labels, predicted_labels)
precision = precision_score(actual_labels, predicted_labels, average='weighted')
recall = recall_score(actual_labels, predicted_labels, average='weighted')
f1 = f1_score(actual_labels, predicted_labels, average='weighted')
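
To actually see the numbers, and optionally a per-class breakdown, something like the following works. The classification_report call is my addition, and stripping whitespace first guards against stray spaces or newlines in the generated text:

from sklearn.metrics import classification_report

# Normalize whitespace so ' clickbait\n' and 'clickbait' compare equal
actual_clean = [label.strip() for label in actual_labels]
predicted_clean = [label.strip() for label in predicted_labels]

print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
print(classification_report(actual_clean, predicted_clean))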

Evaluation Results:

The results are in, and they look great! The model achieved approximately 99.38% across all evaluation metrics: accuracy, precision, recall, and F1-score. The fact that all the scores are identical indicates that the model’s performance is consistent across metrics; this usually happens when the data is balanced, and in our case the test data contains an equal number of both labels. Overall, the scores are indicative of the model’s power and its ability to handle text classification tasks, even with limited training data.

Conclusion

In this article, we explored how well cohere.ai’s generative language model works for text classification with a small dataset. We witnessed how these LLMs can outperform traditional AI methods, achieving great results even when trained on just a fraction of the clickbait dataset. With its pre-trained knowledge and capacity to capture intricate language patterns, cohere.ai’s model opens new doors for accurate text classification.

Here is the GitHub link to the complete code.

