(Part 1) Build your own RAG with Mistral-7B and LangChain

Madhav Thaker
Nov 14, 2023


DALL-E generated image of a fantasy football assistant

LLMs have taken the world by storm, and rightfully so. They help you with everyday tasks like building a coding project, creating a recipe, or even writing a Medium article. While this is great, major limitations still exist. In my opinion, there are two key issues:

  • They do not have access to real time information.
  • They are prone to hallucinations.

I’m sure many of you have seen a response like this, where the model can’t answer because the information falls outside its training data:

https://chat.openai.com/

Or when they confidently suggest a solution which isn’t correct:

https://chat.openai.com/

We are at the mercy of when new models are released with expanded training data. But what if we want an assistant that can answer a question about something that happened in 2023? What about something that happened this week? Or what if we want to give the LLM more information so that we can improve response accuracy?

In these scenarios, fine-tuning is an option, but it comes with its own risks and challenges:

  1. Model Drift: Over time, as the model is continuously fine-tuned with new data, it might start to drift from its original performance and behavior. This could lead to unexpected and undesirable results.
  2. Costly and Complex: This approach not only presents significant technical challenges, but it also incurs substantial costs. The need to fine-tune our model on a weekly basis would require a considerable investment in terms of computational resources and expert manpower, making it a complex and expensive endeavor.

So how do we address this? One approach is Retrieval Augmented Generation (RAG). In Part 1 of this RAG series, we’ll cover:

  1. What are RAGs?
  2. How do they work?
  3. How to leverage Mistral 7B via Hugging Face and LangChain to build your own.
  4. Real examples of a small RAG in action!

For my use case, I’m going to attempt to create a Fantasy Football (NFL) assistant that can answer questions on what is currently happening in the season (something existing LLMs can’t do!).

Before jumping in, there are a few key concepts in this article that may not be familiar to everyone. Just in case, I’ve highlighted them below along with a few helpful resources to get you up to speed:

  1. Text embeddings:
  2. Vector databases:

For those of you who want to jump in and start prototyping, I’ve included the end-to-end code at the end of this article.

RAGs

What is a RAG?

Retrieval Augmented Generation (RAG), simply put, helps LLMs by giving them access to external data so that they can generate a response with additional context. This context can be anything from recent news to audio transcripts of a lecture or, in my case, fantasy football news.

How do they work?

You can think of RAG as an LLM with vector search attached. Here’s a high-level diagram to illustrate how they work:

High Level RAG Architecture

Here are the 4 key steps that take place:

  1. Load a vector database with encoded documents.
  2. Encode the query into a vector using a sentence transformer.
  3. Using the encoded query, retrieve relevant context from the vector database.
  4. Leverage context along with the query to prompt the LLM.
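To make the four steps concrete, here’s a toy, dependency-free sketch. It is not how LangChain or FAISS work internally: a bag-of-words counter stands in for the sentence transformer, a plain Python list stands in for the vector database, and the documents are made-up placeholder text rather than real news.

```python
import math
import re
from collections import Counter


def embed(text):
    """Toy stand-in for a sentence transformer: bag-of-words counts."""
    return Counter(re.findall(r"[a-z']+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[token] for token, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


# Step 1: load a "vector database" with encoded documents (hypothetical toy data).
documents = [
    "Player A is listed as questionable this week with an ankle injury.",
    "Team B has allowed the most rushing yards to running backs this season.",
    "Player C has seen a rising target share over the last three games.",
]
index = [(doc, embed(doc)) for doc in documents]


def retrieve(query, k=2):
    query_vec = embed(query)  # Step 2: encode the query
    ranked = sorted(index, key=lambda item: cosine(query_vec, item[1]), reverse=True)
    return [doc for doc, _ in ranked[:k]]  # Step 3: top-k most similar documents


def build_prompt(query, k=2):
    context = "\n".join(retrieve(query, k))  # Step 4: combine context with the query
    return f"[INST] Context:\n{context}\n\nQuestion: {query} [/INST]"


print(build_prompt("Is Player A injured this week?", k=1))
```

The real pipeline swaps each toy piece for a stronger one (dense embeddings, a FAISS index, an actual LLM call), but the data flow is the same.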

We’ll dive into all of this throughout the rest of the blog.

How do we build it?

For this exercise, we’ll take a handful of articles from fantasypros.com and ask the LLM questions that it would only be able to answer if it has access to that data.

https://www.fantasypros.com/2023/11/nfl-week-10-sleeper-picks-player-predictions-2023/

Environment Details

Compute: 1 Nvidia L4 GPU

Dependencies:

!pip install -q torch datasets langchain playwright html2text sentence_transformers faiss-cpu
!pip install -q accelerate==0.21.0 \
peft==0.4.0 \
bitsandbytes==0.40.2 \
transformers==4.31.0 \
trl==0.4.7

Load a quantized Mistral-7B Model

We are going to be using Hugging Face to load our quantized Mistral-7B model. Check out Mistral AI’s announcement article to learn more about their breakthrough 7B model.

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

#################################################################
# Tokenizer
#################################################################

model_name='mistralai/Mistral-7B-Instruct-v0.1'

model_config = transformers.AutoConfig.from_pretrained(
    model_name,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

#################################################################
# bitsandbytes parameters
#################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

#################################################################
# Set up quantization config
#################################################################
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

#################################################################
# Load the quantized pre-trained model
#################################################################
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
)
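For intuition about what 4-bit loading does to the weights, here’s a toy absmax quantizer applied to a single block of weights. This is a swapped-in, simpler scheme: bitsandbytes’ nf4 type uses 16 levels placed to suit normally distributed weights, with a scale per block, whereas this sketch uses evenly spaced integer levels. Both functions are hypothetical illustrations, not bitsandbytes APIs.

```python
def quantize_block(values, bits=4):
    """Absmax quantization: scale the block so its largest |value| maps to the top code."""
    levels = 2 ** (bits - 1) - 1                 # 7 for signed 4-bit codes in [-7, 7]
    scale = max(abs(v) for v in values) or 1.0   # guard against an all-zero block
    codes = [round(v / scale * levels) for v in values]
    return codes, scale


def dequantize_block(codes, scale, bits=4):
    """Map integer codes back to approximate floating-point weights."""
    levels = 2 ** (bits - 1) - 1
    return [c / levels * scale for c in codes]


weights = [0.6, -1.0, 0.2, 0.0]
codes, scale = quantize_block(weights)
restored = dequantize_block(codes, scale)
print(codes, scale)
print(restored)
```

Each weight ends up stored as a small integer code plus one shared scale per block, and the round trip loses at most half a quantization step (scale / 14 here). That lost precision is the price paid for fitting a 7B model on a single GPU.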

Let’s try out Mistral 7B by asking it what it knows about fantasy football.

inputs_not_chat = tokenizer.encode_plus("[INST] Tell me about fantasy football? [/INST]", return_tensors="pt")['input_ids'].to('cuda')

generated_ids = model.generate(inputs_not_chat,
                               max_new_tokens=1000,
                               do_sample=True)
decoded = tokenizer.batch_decode(generated_ids)
print(decoded[0])
<s> [INST] Tell me about fantasy football? [/INST] 

Fantasy football is a type of online game where players create and manage
their own virtual team, composed of real-life football players, to compete
against other teams in a simulated league. The game typically involves
selecting players for their team based on their perceived abilities and
potential for scoring points in the real-life football games that are
played throughout the season. Throughout the season, players can make
trades and adjustments to their team in an attempt to maximize their points
and win their league. Some popular fantasy football platforms include Yahoo
Fantasy Football, ESPN Fantasy Football, and NFL.com Fantasy Football.</s>

While this isn’t the most complex query, this level of coherence is still pretty impressive coming from a 7B model.

Quick tangent: here’s a handy function that shows exactly how many parameters remain trainable after this quantization.

def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(model))
trainable model parameters: 262410240
all model parameters: 3752071168
percentage of trainable model parameters: 6.99%

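Seeing this really drove home how much quantization changes what’s under the hood. These counts reflect how the weights are stored rather than the model’s true size: bitsandbytes packs two 4-bit values into each stored element, so the quantized linear layers report roughly half their real parameter count and are frozen (requires_grad=False). The ~7% that remains trainable is the embedding table, the output head (lm_head), and the layer norms, which stay in 16-bit. Those trainable weights only matter if you fine-tune the model, which we won’t do here; our RAG setup changes the model’s behavior through the prompt alone.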
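As a sanity check on the numbers above, here’s the arithmetic. It assumes Mistral-7B-v0.1’s published configuration (32,000-token vocabulary, hidden size 4,096, MLP size 14,336, 32 layers, 8 key/value heads of dimension 128, no biases) and assumes bitsandbytes quantizes every linear layer except lm_head while packing two 4-bit values per stored element.

```python
# Mistral-7B-v0.1 shapes (assumed from its published config)
vocab, hidden, intermediate, layers = 32_000, 4_096, 14_336, 32
kv_dim = 8 * 128  # 8 key/value heads x head_dim 128 (grouped-query attention)

# Linear layers inside each decoder block (no biases): these get quantized to 4-bit
attention = 2 * hidden * hidden + 2 * hidden * kv_dim  # q_proj, o_proj, k_proj, v_proj
mlp = 3 * hidden * intermediate                         # gate_proj, up_proj, down_proj
linear = layers * (attention + mlp)

# Left in 16-bit (and trainable): embeddings, lm_head, per-layer and final RMSNorm weights
kept_16bit = 2 * vocab * hidden + layers * 2 * hidden + hidden

stored_linear = linear // 2                   # two 4-bit values packed per stored element
reported_total = stored_linear + kept_16bit   # what summing param.numel() reports
true_total = linear + kept_16bit

print(f"trainable: {kept_16bit:,}")
print(f"reported total: {reported_total:,}")
print(f"true total: {true_total:,}")
print(f"trainable share: {100 * kept_16bit / reported_total:.2f}%")
```

Under these assumptions the script reproduces the printed counts exactly (262,410,240 trainable out of 3,752,071,168 reported), while the unpacked model has 7,241,732,096 parameters.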

Ok, now back to building RAGs!

Creating a RAG using LangChain

For the purposes of this article, I’m going to create all of the necessary components using LangChain. This gives us what we need to build a quick end-to-end proof of concept (POC). You can find more information in their docs. It’s also worth mentioning that LangChain isn’t the only (or necessarily the best) option for building RAGs.

For this exercise, we will be using FAISS to create our vector database. FAISS (Facebook AI Similarity Search) is a library developed by Facebook AI for efficient similarity search and clustering of dense vectors. In this context, it lets us retrieve relevant context from external sources. We’ll walk through creating the vector database using LangChain’s FAISS integration.
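Conceptually, a vector database answers one question: which stored vectors are closest to the query vector? The simplest FAISS index, IndexFlatL2, answers it with an exact, exhaustive search (in FAISS itself: `index = faiss.IndexFlatL2(d)`, `index.add(xb)`, then `D, I = index.search(xq, k)`), and I believe it is also what LangChain’s FAISS wrapper builds by default. The pure-Python sketch below mimics that exhaustive search and its (distances, ids) return shape, for intuition only; `flat_l2_search` is a hypothetical helper, and FAISS’s real speed comes from optimized vectorized code and, for large collections, approximate index types.

```python
def flat_l2_search(database, queries, k):
    """Exhaustive k-nearest-neighbour search, returning (distances, ids) per query.

    Like FAISS's IndexFlatL2, distances are squared Euclidean distances,
    sorted from nearest to farthest.
    """
    all_distances, all_ids = [], []
    for q in queries:
        scored = sorted(
            (sum((a - b) ** 2 for a, b in zip(q, vec)), i)
            for i, vec in enumerate(database)
        )[:k]
        all_distances.append([d for d, _ in scored])
        all_ids.append([i for _, i in scored])
    return all_distances, all_ids


# Three 2-D "embeddings" and one query vector (made-up numbers)
database = [[0.0, 0.0], [1.0, 0.0], [0.0, 3.0]]
distances, ids = flat_l2_search(database, [[0.9, 0.1]], k=2)
print(ids)  # nearest first
```

Real embeddings have hundreds of dimensions, but the lookup is the same idea: the ids point back to the text chunks that get handed to the LLM.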

Fortunately for us, LangChain has built-in capabilities that let us create and query our index.

Here’s a more detailed diagram showing how each piece interacts. We’ll break down each piece and provide code examples:

Detailed RAG architecture using LangChain

Create Vector Database

First, we’ll walk through how to create the vector database.

from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import AsyncChromiumLoader
from langchain.document_transformers import Html2TextTransformer
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
import nest_asyncio

# Lets AsyncChromiumLoader run its event loop inside a notebook
nest_asyncio.apply()

articles = ["https://www.fantasypros.com/2023/11/rival-fantasy-nfl-week-10/",
            "https://www.fantasypros.com/2023/11/5-stats-to-know-before-setting-your-fantasy-lineup-week-10/",
            "https://www.fantasypros.com/2023/11/nfl-week-10-sleeper-picks-player-predictions-2023/",
            "https://www.fantasypros.com/2023/11/nfl-dfs-week-10-stacking-advice-picks-2023-fantasy-football/",
            "https://www.fantasypros.com/2023/11/players-to-buy-low-sell-high-trade-advice-2023-fantasy-football/"]

# Scrapes the blogs above (requires Playwright's browsers: !playwright install)
loader = AsyncChromiumLoader(articles)
docs = loader.load()

# Converts HTML to plain text
html2text = Html2TextTransformer()
docs_transformed = html2text.transform_documents(docs)

# Chunk text (chunk_size is measured in characters by default)
text_splitter = CharacterTextSplitter(chunk_size=100,
                                      chunk_overlap=0)
chunked_documents = text_splitter.split_documents(docs_transformed)

# Load chunked documents into the FAISS index
db = FAISS.from_documents(chunked_documents,
                          HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2'))


# Connect query to FAISS index using a retriever
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={'k': 4}
)

And just like that, we have a vector database set up. Easy enough, but let’s dive into what’s actually happening here.

text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)

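In our database, we can’t store each document as a single unit; we need to break it into pieces so that we can search for the parts of each document that are relevant. This becomes much more critical as we process thousands of articles. This LangChain class lets you define how large the indexed chunks should be. By default, CharacterTextSplitter measures chunk_size in characters (not tokens) and splits on blank lines ("\n\n"), so here each chunk holds up to roughly 100 characters with no overlap between chunks. A single paragraph longer than that can’t be split further and is kept whole as an oversized chunk (LangChain logs a warning when this happens).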
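Here’s a simplified sketch of how separator-based character splitting behaves: split on the separator, then greedily merge neighbouring pieces while the result stays within chunk_size characters. It approximates CharacterTextSplitter’s default behavior but is not LangChain’s code (the real class also handles chunk_overlap and whitespace stripping); `split_text` is a hypothetical helper name.

```python
def split_text(text, chunk_size=100, separator="\n\n"):
    """Greedily merge separator-delimited pieces into chunks of <= chunk_size characters."""
    pieces = [p for p in text.split(separator) if p]
    chunks, current = [], ""
    for piece in pieces:
        candidate = piece if not current else current + separator + piece
        if len(candidate) <= chunk_size:
            current = candidate          # still fits: keep growing this chunk
        else:
            if current:
                chunks.append(current)   # close the current chunk
            current = piece              # may itself exceed chunk_size; kept whole
    if current:
        chunks.append(current)
    return chunks


text = "a" * 40 + "\n\n" + "b" * 40 + "\n\n" + "c" * 150
chunks = split_text(text, chunk_size=100)
print([len(c) for c in chunks])
```

The first two 40-character paragraphs merge into one 82-character chunk, while the 150-character paragraph has no separator inside it and ends up as a single chunk that exceeds chunk_size.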

FAISS.from_documents(chunked_documents,
                     HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2'))

This step encodes the chunks we just created and indexes them in FAISS. In our case, we encode them with the HuggingFaceEmbeddings() class, which gives us access to the sentence transformer models (full list here). In the example above, I’m using sentence-transformers/all-mpnet-base-v2, which maps each chunk to a 768-dimensional vector, but there are many others to choose from.

Let’s test out our database by asking it a question and seeing if it can retrieve a relevant chunk:

query = "What did Laporta say?"
docs = db.similarity_search(query)
print(docs[0].page_content)
"It's football, things happen, we certainly work our bodies pretty hard so 
I'm feeling a lot better today," LaPorta said.

That’s a good sign. With that context, the LLM will have a much better chance of accurately answering a similar query.

Now, to get LangChain and our LLM to work together, we need something called a retriever. Think of it as the go-between that lets the LLM pull context from the vector database.

retriever = db.as_retriever(
search_type="similarity",
search_kwargs={'k': 4}
)

This means the retriever returns the top 4 chunks, ranked by similarity between the query and each chunk. You can find other options here.

Now that we have all of the necessary pieces, we can start creating our “chains”.

Create LLM Chain

The first chain we’ll build is the LLM chain for the Mistral LLM.

import transformers
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains import LLMChain

text_generation_pipeline = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    do_sample=True,  # temperature only takes effect when sampling is enabled
    temperature=0.2,
    repetition_penalty=1.1,
    return_full_text=True,
    max_new_tokens=300,
)

prompt_template = """
### [INST]
Instruction: Answer the question based on your
fantasy football knowledge. Here is context to help:

{context}

### QUESTION:
{question}

[/INST]
"""

mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Create prompt from prompt template
prompt = PromptTemplate(
input_variables=["context", "question"],
template=prompt_template,
)

# Create llm chain
llm_chain = LLMChain(llm=mistral_llm, prompt=prompt)

Crucially, aside from creating the chain, we add a {context} placeholder to our prompt so that context can be passed in. This is where the chunks retrieved from our FAISS index will go.
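PromptTemplate fills its placeholders with f-string-style substitution by default, so Python’s own str.format shows exactly what the LLM will receive for a template like this one. The context and question below are made-up placeholders, not retrieved data.

```python
# Same template text as the prompt_template used for the LLM chain
prompt_template = """
### [INST]
Instruction: Answer the question based on your
fantasy football knowledge. Here is context to help:

{context}

### QUESTION:
{question}

[/INST]
"""

# Hypothetical context and question, for illustration only
filled = prompt_template.format(
    context="Player A is questionable with an ankle injury.",
    question="Should I start Player A this week?",
)
print(filled)
```

The retrieved text lands between the instruction and the question, inside the [INST] ... [/INST] markers that Mistral’s instruct model expects.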

So now we have an LLMChain that isn’t connected to an external data source. Let’s ask it a question it could only answer if it had “read” the articles we indexed.

llm_chain.invoke({"context":"", 
"question": "Should I pick up Alvin Kamara for my fantasy team?"})

"Whether or not you should pick up Alvin Kamara for your fantasy team
depends on a few factors, such as the specific league rules and roster
requirements, the current performance of Kamara and other players in your
league, and your overall strategy for building your team.

As we’ve all probably seen, it gives us a generic answer because it doesn’t have access to recent information.

Create RAG Chain

Let’s integrate our LLMChain with our FAISS retriever and put it all together:

from langchain.schema.runnable import RunnablePassthrough

query = "Should I pick up Alvin Kamara for my fantasy team?"

retriever = db.as_retriever()

rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | llm_chain
)

rag_chain.invoke(query)

"Based on the information provided, it seems that Alvin Kamara has been
performing well as a running back in the NFL. He has recorded 36 or more
receiving yards in four of his last five games and has topped 33 or more in
five of six games this season. He is also being peppered with targets
regardless of opponent, which suggests that he is a valuable asset to have
on your fantasy team. Additionally, he has field goal opportunities, which
could provide additional points for"
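One detail worth knowing about this chain: the retriever returns a list of Document objects, and because that list goes straight into {context}, the prompt receives its Python repr (the Document(...) wrappers and metadata included) rather than clean text. A common pattern is a small formatting function. The sketch below uses a hypothetical `Doc` dataclass as a stand-in for LangChain’s Document so it runs without LangChain installed; `format_docs` is a helper name introduced here, not a LangChain API.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Hypothetical stand-in for LangChain's Document (page_content + metadata)."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def format_docs(docs):
    """Join retrieved chunks into one plain-text context block."""
    return "\n\n".join(doc.page_content for doc in docs)


retrieved = [Doc("Chunk one text."), Doc("Chunk two text.", {"source": "example"})]
raw = str(retrieved)            # roughly what {context} receives without formatting
clean = format_docs(retrieved)  # what it receives with formatting
print(raw)
print(clean)
```

In LangChain, a plain function placed after a runnable with `|` is wrapped automatically, so the chain could become `{"context": retriever | format_docs, "question": RunnablePassthrough()} | llm_chain`, giving the model cleaner context to read.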

Let’s take a look at some more examples:

query = "I have Josh Jacobs, should I trade him for Kareem Hunt?"

rag_chain.invoke(query)

"Based on the information provided, it seems like Kareem Hunt could be a
good trade target for Josh Jacobs. Hunt is currently the fourth highest
in the NFL in terms of inside the 5 rushing attempt percentage, and he has
been playing well for the Cleveland Browns since being acquired from the
Kansas City Chiefs. Additionally, the Browns are expected to continue relying
heavily on their running backs due to injuries, which could make Hunt a
valuable asset in your lineup."

query = "What are your thoughts on Ja'Marr Chase going into this week?"

rag_chain.invoke(query)

"It seems that Ja'Marr Chase is currently practicing and may be available
for play in Week 10 of the NFL season. While there is no specific information
about his performance or injury status, it is worth noting that he has been
mentioned as a potential sleeper pick and stacking advice for DFS players.
Additionally, there have been several articles discussing his potential
impact on the Bengals offense and his role in their success. Overall, while
there may be some uncertainty around his performance, it appears that Ja'Marr
Chase could be a valuable addition to any fantasy football team looking for
a reliable wide receiver option."

This might be my favorite example:

query = "Should I trade Saquon Barkley? What are some alternatives."

rag_chain.invoke(query)

"It seems like there are mixed opinions on whether or not to sell
Saquon Barkley. Some people are recommending selling him due to his history
of injuries and the fact that he plays in an awful offense. Others are
suggesting holding onto him because of his volume and potential for upside.

If you decide to sell Saquon Barkley, one recommendation for an alternative
could be RB Derrick Henry. Henry has been performing well this season and
could provide a more reliable source of points for your team. However, keep
in mind that trading for an RB can be risky, especially if you're giving up
valuable assets in return. It's important to carefully consider your options
before making any trades."

This gives me honest and accurate advice given what has transpired over the last few weeks of the season. For someone not keeping up with the minute details, this sort of input would be very useful!

This shows how effective an LLM can be after we add just a handful of articles. Here’s a quick recap of what is happening:

  1. PromptTemplate Creation: We initiated the process by creating a PromptTemplate. This template requires two inputs: a context and a question. The context provides background information relevant to the question, while the question is what we want our LLM to answer.
  2. Chain Creation: Next, we created a chain. This chain is a sequence of operations that allows us to invoke a query.
  3. RunnablePassthrough Usage: The query is passed along using RunnablePassthrough(), a LangChain runnable that forwards its input unchanged; here it fills the prompt’s question field with the raw query.
  4. Retriever Invocation: The query is also passed into the retriever. The retriever queries our FAISS index, a database designed for efficient similarity search and clustering of dense vectors, and retrieves the relevant context.
  5. Context Integration: The retrieved context is then integrated into our prompt. This step is crucial as it provides the necessary background information that aids the LLM in generating a more accurate and context-aware response.
  6. LLM Invocation: Finally, the enriched prompt is passed into the LLM. In this demonstration, we used a quantized Mistral-7B model, which is a powerful language model capable of generating high-quality text.
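Steps 2 through 4 of the recap boil down to two composition rules: a dict of runnables sends the same input to every branch and collects the results into a dict, and `|` feeds each step’s output into the next. Here’s a plain-Python imitation of that data flow, for intuition only. It is not LangChain’s implementation, and `fake_retriever`, `fake_prompt`, and `fake_llm` are hypothetical stand-ins.

```python
def passthrough(x):
    """Like RunnablePassthrough: returns its input unchanged."""
    return x


def parallel(branches):
    """Like a dict step in a chain: feeds the same input to every branch."""
    def run(x):
        return {name: branch(x) for name, branch in branches.items()}
    return run


def sequence(*steps):
    """Like the | operator: each step's output becomes the next step's input."""
    def run(x):
        for step in steps:
            x = step(x)
        return x
    return run


def fake_retriever(query):
    return ["chunk about " + query]


def fake_prompt(inputs):
    return f"[INST] {inputs['context']} {inputs['question']} [/INST]"


def fake_llm(prompt):
    return "answer to: " + prompt


rag = sequence(parallel({"context": fake_retriever, "question": passthrough}),
               fake_prompt, fake_llm)
print(rag("q"))
```

Swapping the fakes for the FAISS retriever, the PromptTemplate, and the Mistral pipeline gives the shape of the real chain.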

In this example, we indexed just a few articles, but imagine a pipeline that continuously loads a vector database with the external sources you care about, giving your LLM access to an ever-growing knowledge base. That immediately makes your LLM much more useful as a daily tool.

Next, it’s time for us to update our RAG to allow the end user to conversationally interact with their vector database. See how in Part 2 of my RAG exploration!

Let me know if you have any questions or suggestions! Please reach out to me via LinkedIn!

Here is the end-to-end code. I’ve also included this in a notebook.

!pip install -q -U torch datasets transformers tensorflow langchain playwright html2text sentence_transformers faiss-cpu
!pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 trl==0.4.7

import os
import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    pipeline,
)

from langchain.text_splitter import CharacterTextSplitter
from langchain.document_transformers import Html2TextTransformer
from langchain.document_loaders import AsyncChromiumLoader

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.llms import HuggingFacePipeline
from langchain.chains import LLMChain

import nest_asyncio
#################################################################
# Tokenizer
#################################################################

model_name='mistralai/Mistral-7B-Instruct-v0.1'

model_config = transformers.AutoConfig.from_pretrained(
model_name,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

#################################################################
# bitsandbytes parameters
#################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

#################################################################
# Set up quantization config
#################################################################
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
load_in_4bit=use_4bit,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=use_nested_quant,
)

# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

#################################################################
# Load pre-trained config
#################################################################
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
)


def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(model))

text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    do_sample=True,  # temperature only takes effect when sampling is enabled
    temperature=0.2,
    repetition_penalty=1.1,
    return_full_text=True,
    max_new_tokens=1000,
)

mistral_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

!playwright install
!playwright install-deps

import nest_asyncio
nest_asyncio.apply()

# Articles to index
articles = ["https://www.fantasypros.com/2023/11/rival-fantasy-nfl-week-10/",
"https://www.fantasypros.com/2023/11/5-stats-to-know-before-setting-your-fantasy-lineup-week-10/",
"https://www.fantasypros.com/2023/11/nfl-week-10-sleeper-picks-player-predictions-2023/",
"https://www.fantasypros.com/2023/11/nfl-dfs-week-10-stacking-advice-picks-2023-fantasy-football/",
"https://www.fantasypros.com/2023/11/players-to-buy-low-sell-high-trade-advice-2023-fantasy-football/"]

# Scrapes the blogs above
loader = AsyncChromiumLoader(articles)
docs = loader.load()

# Converts HTML to plain text
html2text = Html2TextTransformer()
docs_transformed = html2text.transform_documents(docs)

# Chunk text
text_splitter = CharacterTextSplitter(chunk_size=100,
chunk_overlap=0)
chunked_documents = text_splitter.split_documents(docs_transformed)

# Load chunked documents into the FAISS index
db = FAISS.from_documents(chunked_documents,
HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2'))

retriever = db.as_retriever()

# Create prompt template
prompt_template = """
### [INST] Instruction: Answer the question based on your fantasy football knowledge. Here is context to help:

{context}

### QUESTION:
{question} [/INST]
"""

# Create prompt from prompt template
prompt = PromptTemplate(
input_variables=["context", "question"],
template=prompt_template,
)

# Create llm chain
llm_chain = LLMChain(llm=mistral_llm, prompt=prompt)

rag_chain = (
{"context": retriever, "question": RunnablePassthrough()}
| llm_chain
)

rag_chain.invoke("Should I start Gibbs next week for fantasy?")

Enjoy!
