Simple RAG with OpenAI and scikit-learn

Alexandr Dzhumurat
Published in Python’s Gurus
4 min read · Jun 4, 2024

RAG (Retrieval-Augmented Generation) is an approach in natural language processing that combines information retrieval with text generation. Traditional text generation models, like GPT, generate responses from patterns learned during training on vast amounts of data, so they may lack specific or up-to-date information. RAG addresses this by first retrieving relevant documents or snippets from a database and then passing them to the model as context for generation. Grounding the output in retrieved text makes it more contextually relevant and less likely to be factually wrong, although it does not guarantee accuracy. Because the retrieval index can be updated independently of the model, RAG systems can incorporate fresh information without retraining, which makes them effective for applications like question answering, customer support, and content creation. Let’s build a RAG app for medicine recommendations!
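The two stages described above can be sketched in a few lines. This is a toy illustration under stated assumptions, not the pipeline built in this article: TF-IDF from scikit-learn stands in for the transformer embedder, the three documents are made up, and the generation stage is reduced to building the prompt an LLM would receive (no model is called). The names `toy_retrieve` and `toy_prompt` are hypothetical.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Hypothetical mini-corpus standing in for a document database
toy_docs = [
    "This ibuprofen works fast for my headache.",
    "Great vitamin C gummies, tasty and cheap.",
    "Aspirin tablets helped with my migraine and headache.",
]

def toy_retrieve(query, docs, top_k=2):
    """Retrieval stage: rank docs by TF-IDF cosine similarity to the query."""
    vectorizer = TfidfVectorizer()
    doc_vecs = vectorizer.fit_transform(docs)    # sparse matrix, one row per doc
    query_vec = vectorizer.transform([query])    # sparse matrix with a single row
    sims = cosine_similarity(query_vec, doc_vecs)[0]
    return sims.argsort()[::-1][:top_k].tolist()  # highest similarity first

def toy_prompt(query, docs, ids):
    """Generation stage (stubbed): the retrieved text becomes context in the prompt."""
    context = "\n".join(f"- {docs[i]}" for i in ids)
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"

toy_ids = toy_retrieve("headache tablets", toy_docs)
print(toy_ids)  # [2, 0]
print(toy_prompt("headache tablets", toy_docs, toy_ids))
```

The aspirin review ranks first because it matches both query terms, while the vitamin review shares no terms with the query and is left out of the context entirely.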

Photo by Leonardo Toshiro Okubo on Unsplash

We will use the Amazon Reviews dataset. The plan is to build RAG over a single category of it: Health and Personal Care.

The first step is to embed the review texts. We will use a transformer-based model as the embedder.

import gzip
import json
import os

import numpy as np
from sentence_transformers import SentenceTransformer

# Assumption: the article does not show the data directory or the embedding model;
# adjust both to your setup.
root_data_dir = 'data'
model = SentenceTransformer('all-MiniLM-L6-v2')


def read_raw_data(file_name, limit: int, fields=None):
    """Read a gzipped JSON-lines file; keep only `fields` (if given); stop after `limit` rows (-1 = all)."""
    file_path = os.path.join(root_data_dir, file_name)
    res = []
    with gzip.open(file_path, 'rt') as gz_file:
        for line in gz_file:
            data = json.loads(line.strip())
            if fields is not None:
                res.append({i: j for i, j in data.items() if i in fields})
            else:
                res.append(data)
            if limit == len(res):
                break
    print('Dataset num items: %d' % len(res))
    return res


def load_corpus(db):
    # The review text is what gets embedded
    return [item['text'] for item in db]


def train_embeds(corpus_texts, embedder, sentence_embedding_path, overwrite=False):
    # Reuse cached embeddings from disk unless overwrite=True
    if os.path.exists(sentence_embedding_path) and not overwrite:
        print('corpus loading from %s' % sentence_embedding_path)
        passage_embeddings = np.load(sentence_embedding_path)
    else:
        print('num rows %d' % len(corpus_texts))
        passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True)
        passage_embeddings = np.asarray(passage_embeddings, dtype='float32')
        np.save(sentence_embedding_path, passage_embeddings)
        print('corpus saved to %s' % sentence_embedding_path)
    print('Num embeddings %d' % passage_embeddings.shape[0])
    return passage_embeddings


db = read_raw_data(
    'Health_and_Personal_Care.jsonl.gz',
    limit=-1, fields=['rating', 'text', 'title', 'asin']
)
corpus = load_corpus(db)
data_version = 0  # bump to write (and use) a fresh embeddings cache file
embeds = train_embeds(
    corpus, model,
    os.path.join(root_data_dir, f'corpus_embeds_{data_version}.npy'),
    overwrite=False
)
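If you want to check the reading logic without downloading the dataset, the sketch below writes two hypothetical records in the same gzipped JSON-lines layout (one JSON object per line, with the `rating`, `text`, `title` and `asin` fields used above plus one extra field) to a temporary file, then reads them back with field filtering and a row limit. The records and the helper name `read_jsonl_gz` are made up for illustration.

```python
import gzip
import json
import os
import tempfile

def read_jsonl_gz(path, limit=-1, fields=None):
    """Read a gzipped JSON-lines file; keep only `fields`; stop after `limit` rows (-1 = no limit)."""
    rows = []
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            if fields is not None:
                record = {k: v for k, v in record.items() if k in fields}
            rows.append(record)
            if len(rows) == limit:
                break
    return rows

# Two hypothetical review records; "user_id" plays the role of a field we do not need
records = [
    {"rating": 5.0, "title": "Works", "text": "Fast headache relief.", "asin": "A1", "user_id": "u1"},
    {"rating": 2.0, "title": "Meh", "text": "Did not help.", "asin": "A2", "user_id": "u2"},
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample.jsonl.gz")
    with gzip.open(path, "wt", encoding="utf-8") as f:
        for r in records:
            f.write(json.dumps(r) + "\n")
    rows = read_jsonl_gz(path, limit=-1, fields=["rating", "text", "title", "asin"])
    first_only = read_jsonl_gz(path, limit=1)

print(rows[0])          # user_id is dropped by the field filter
print(len(first_only))  # 1
```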

The next step is retrieval. We will use plain `cosine_similarity` from scikit-learn as our search engine.


import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


def top_similar(query_embed, candidates_embeds, top=10):
    # Cosine similarity between the query and every corpus embedding
    sims = cosine_similarity(query_embed.reshape(1, -1), candidates_embeds)[0]
    # Rank by signed similarity, highest first (abs() would promote opposite-direction vectors)
    top_similar_idx = [int(i) for i in np.argsort(-sims)[:top]]
    return top_similar_idx


def retrieve(query):
    # Embed the query with the same model that embedded the corpus
    query_embed = model.encode(query, show_progress_bar=False)
    indexes = top_similar(query_embed, embeds)
    return indexes


def get_retrieved_content(db, ids):
    # Map each retrieved index to {asin: review_text}
    res = [{db[i]['asin']: db[i]['text']} for i in ids]
    return res


query = "headache tablets"
ids = retrieve(query)
get_retrieved_content(db, ids)
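A note on the ranking step: cosine similarity lies in [-1, 1], and a score near -1 means the two vectors point in opposite directions, so candidates should be ranked by the signed score rather than its absolute value. The sketch below uses made-up 2-D vectors in place of real embeddings to show the difference. It computes cosine similarity as the dot product of L2-normalized vectors, which gives the same scores as scikit-learn's `cosine_similarity`; `top_k_cosine` is a hypothetical helper.

```python
import numpy as np

def top_k_cosine(query, candidates, k=2):
    """Return (indices of the k most similar candidates, all cosine scores)."""
    q = query / np.linalg.norm(query)
    c = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    sims = c @ q                                  # dot product of unit vectors = cosine similarity
    return np.argsort(-sims)[:k].tolist(), sims   # signed score, highest first

# Hypothetical 2-D "embeddings"
toy_query = np.array([1.0, 0.0])
toy_candidates = np.array([
    [0.9, 0.1],    # nearly the same direction as the query
    [-1.0, 0.0],   # opposite direction: cosine = -1
    [0.5, 0.5],    # 45 degrees away: cosine ~ 0.707
])

top, sims = top_k_cosine(toy_query, toy_candidates)
print(top)                                     # [0, 2]
print(np.argsort(-np.abs(sims))[:2].tolist())  # [1, 0]: abs() ranks the opposite vector first
```

For larger corpora, a common design choice is to normalize the corpus embeddings once up front, so each query costs only a single matrix-vector product.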

The last step is generation: we construct a prompt from the retrieved data and send it to the LLM.
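The prompt in the code that follows is built with an f-string and carries a `# TODO: use jinja2`. One way that TODO could be resolved is sketched here, assuming the `jinja2` package is installed. The template wording follows the article's prompt, `render_prompt` is a hypothetical helper, and the sample candidates are made up.

```python
from jinja2 import Template

# The article's prompt instructions, with the candidate loop moved into the template
PROMPT_TEMPLATE = Template(
    "Below are product reviews, one per row, each labeled with its item_id.\n"
    "{% for c in candidates %}item_id: {{ c.asin }}; review: {{ c.text }}\n{% endfor %}"
    "Utilize reviews to determine the best item_id.\n"
    "Avoid including actual reviews; rephrase them succinctly.\n"
    'Keep the recommendation under 50 words. Avoid starting with "Based on reviews"; '
    "opt for a more creative approach!\n"
    "Recommendation:"
)

def render_prompt(candidates):
    """Render the prompt from a list of dicts with 'asin' and 'text' keys."""
    return PROMPT_TEMPLATE.render(candidates=candidates)

# Hypothetical retrieved items
sample = [
    {"asin": "A1", "text": "Fast headache relief."},
    {"asin": "A2", "text": "Did not help much."},
]
print(render_prompt(sample))
```

Keeping the wording in a template means the prompt can be edited without touching the Python that collects the candidates.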


import os

import backoff
import openai
from openai import OpenAI


client = OpenAI(
    api_key=os.environ["OPENAI_API_KEY"]
)


# Retry with exponential backoff on transient API errors and on malformed answers
# (RuntimeError). max_tries caps the retries so a persistently bad answer cannot loop forever.
@backoff.on_exception(
    backoff.expo,
    (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError,
     openai.InternalServerError, RuntimeError),
    max_tries=5,
)
def gpt_query(gpt_params, verbose: bool = False, reject_malformed: bool = False) -> dict:
    print('connecting OpenAI...')
    if verbose:
        print(gpt_params["messages"][1]["content"])
    response = client.chat.completions.create(
        **gpt_params
    )
    gpt_response = response.choices[0].message.content
    if reject_malformed:
        # Treat brackets, braces or questions in the answer as a failed generation and retry
        if '[' in gpt_response or '?' in gpt_response or '{' in gpt_response:
            raise RuntimeError('malformed response')
    res = {'recs': gpt_response}
    res.update({
        'completion_tokens': response.usage.completion_tokens,
        'prompt_tokens': response.usage.prompt_tokens,
        'total_tokens': response.usage.total_tokens,
    })
    return res


def gen_candidates(db, ids):
    # One line per retrieved review: "item_id: <asin>; review: <text>"
    candidates = '\n'.join(["item_id: %s; review: %s" % (db[i]['asin'], db[i]['text']) for i in ids])
    return candidates


def prompt_generation(candidates):
    # TODO: use jinja2
    prompt = f"""
Below are product reviews, one per row, each labeled with its item_id.
{candidates}
Utilize reviews to determine the best item_id.
Avoid including actual reviews; rephrase them succinctly.
Keep the recommendation under 50 words. Avoid starting with "Based on reviews"; opt for a more creative approach!
Recommendation:
"""
    return prompt


def generate(db, ids, verbose=False):
    gpt_params = {
        'model': 'gpt-3.5-turbo',
        'max_tokens': 500,
        'temperature': 0.7,
        'top_p': 0.5,
        'frequency_penalty': 0.5,
    }
    gpt_prompt = prompt_generation(gen_candidates(db, ids))
    if verbose:
        print(gpt_prompt)
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant for medicine shopping",
        },
        {
            "role": "user",
            "content": gpt_prompt,
        },
    ]
    gpt_params.update({'messages': messages})
    res = gpt_query(gpt_params, verbose=False)

    return res


res = generate(db, ids)
print(res['recs'])
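The `backoff` decorators on `gpt_query` retry the call with exponentially growing waits when it raises one of the listed errors. To make the mechanism concrete, here is a dependency-free sketch of the same idea. `retry_with_backoff` is a hypothetical helper, not part of `backoff` or the OpenAI SDK, and unlike `backoff` (which applies random jitter by default) it waits exactly `base * factor**attempt` seconds. The `sleep` parameter is injectable so the example runs instantly.

```python
import time

def retry_with_backoff(fn, retry_on=(Exception,), max_tries=5, base=1.0, factor=2.0, sleep=time.sleep):
    """Call fn(); on an exception in `retry_on`, wait base * factor**attempt seconds and retry.

    Re-raises the last exception after `max_tries` failed attempts.
    """
    for attempt in range(max_tries):
        try:
            return fn()
        except retry_on:
            if attempt == max_tries - 1:
                raise
            sleep(base * factor ** attempt)  # waits 1, 2, 4, 8, ... seconds

# Usage with a hypothetical flaky function that fails twice, then succeeds
calls = {"n": 0}
waits = []

def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

result = retry_with_backoff(flaky, retry_on=(RuntimeError,), sleep=waits.append)
print(result, waits)  # ok [1.0, 2.0]
```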

The results are impressive!

The LLM generated a relevant answer grounded in the retrieved documents. As you can see, the model chose the item_id that appears most often among the retrieved reviews and built its argument from those reviews.
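The observation that the model picked the most frequent item_id can be checked directly by counting how often each `asin` appears among the retrieved reviews. This sketch works on a hypothetical list shaped like the output of `get_retrieved_content` (a list of single-entry `{asin: text}` dicts); `most_frequent_item` is an illustrative helper, and the item IDs are made up.

```python
from collections import Counter

def most_frequent_item(retrieved):
    """Return (asin, count) for the item_id that appears most often in the retrieved results.

    `retrieved` is a list of one-key dicts {asin: review_text}. Ties go to the item seen
    first, because Counter.most_common keeps first-encountered order for equal counts.
    """
    counts = Counter(asin for entry in retrieved for asin in entry)
    return counts.most_common(1)[0]

# Hypothetical retrieval output
retrieved = [
    {"B001": "Works for my migraines."},
    {"B002": "Tablets are too big."},
    {"B001": "Fast relief, will buy again."},
    {"B003": "Arrived late."},
    {"B001": "Gentle on the stomach."},
]
print(most_frequent_item(retrieved))  # ('B001', 3)
```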

In conclusion, Retrieval-Augmented Generation (RAG) is an important development in AI-driven text generation. By combining the precision of information retrieval with the generative capabilities of language models, RAG offers a practical way to produce content that is better grounded in source data and more contextually relevant.
