02 Gen AI on Databricks: Building Multi-stage Reasoning Chains in Databricks
In this demo, we will explore how to build a multi-stage reasoning system using Databricks features and LangChain. The demo begins by introducing various components commonly used in multi-stage systems and progresses to building a complete chain for advanced reasoning tasks.
The primary goal is to show how to create chains that handle specific tasks, such as answering user queries, searching external resources, and recommending videos.
Learning Objectives
By the end of this demo, you will:
- Recognize that LangChain chains can include stages that are not LLM-based, such as retrieving external data or using specialized tools.
- Create basic LLM chains to connect prompts and LLMs.
- Use tools to complete various tasks in a multi-stage system.
- Construct sequential chains of multiple LLMChains to perform multi-stage reasoning analysis.
Requirements
To execute this demo, use the following Databricks runtime:
15.4.x-cpu-ml-scala2.12
This runtime ensures compatibility with the required libraries and provides the necessary features for handling advanced AI workflows.
Additionally, ensure the following libraries are installed:
%pip install -U -qq databricks-sdk databricks-vectorsearch langchain-databricks langchain==0.3.7 langchain-community==0.3.7 youtube_search Wikipedia grandalf
Restart the Python environment after installation:
dbutils.library.restartPython()
Before running the demo, execute the classroom setup script to define the configuration variables necessary for the lesson:
%run ../Includes/Classroom-Setup-02
Setting Up
Throughout this demo, we will use the DA object provided by Databricks Academy. It includes variables such as your username, catalog name, schema name, working directory, and dataset locations. Run the code below to verify your setup:
print(f"Username: {DA.username}")
print(f"Catalog Name: {DA.catalog_name}")
print(f"Schema Name: {DA.schema_name}")
print(f"Working Directory: {DA.paths.working_dir}")
print(f"Dataset Location: {DA.paths.datasets}")
Using LLMs and Prompts Without External Libraries
Before diving into LangChain, let's demonstrate a simple interaction with a foundation model using databricks-sdk:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
w = WorkspaceClient()
genre = "romance"
actor = "Brad Pitt"
prompt = f"Tell me about a {genre} movie in which {actor} is one of the actors."
# Query the foundation model serving endpoint directly through the SDK
response = w.serving_endpoints.query(
    name="databricks-meta-llama-3-1-70b-instruct",
    messages=[ChatMessage(role=ChatMessageRole.USER, content=prompt)],
)
print(response.choices[0].message.content)
While this approach works for simple queries, using LangChain can streamline the construction of complex, multi-stage systems by providing modular components and reusable interfaces.
LangChain Basics
LangChain provides essential components for building multi-stage reasoning systems, including prompts, LLMs, retrievers, and tools. Each of these components plays a unique role: prompts define input queries, LLMs perform reasoning, retrievers fetch relevant data, and tools execute specific tasks.
Prompts
Prompts define the structure of your queries to the model. For example:
from langchain.prompts import PromptTemplate
prompt_template = PromptTemplate.from_template("Tell me about a {genre} movie in which {actor} is one of the actors.")
formatted_prompt = prompt_template.format(genre="romance", actor="Brad Pitt")
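Conceptually, this templating behaves much like Python's own str.format. The minimal stand-in below (all names are illustrative, not part of LangChain's API) shows the idea without requiring LangChain to be installed:

```python
# Minimal, hypothetical stand-in for PromptTemplate.from_template,
# implemented with plain str.format -- for illustration only.
class SimplePromptTemplate:
    def __init__(self, template: str):
        self.template = template

    def format(self, **kwargs) -> str:
        # Substitute named placeholders like {genre} and {actor}
        return self.template.format(**kwargs)

tmpl = SimplePromptTemplate("Tell me about a {genre} movie in which {actor} is one of the actors.")
print(tmpl.format(genre="romance", actor="Brad Pitt"))
```

The real PromptTemplate adds input-variable validation and composes with other chain components, but the substitution step is the same.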
LLMs
LLMs serve as the reasoning core of your system. Here's an example using Meta's Llama 3.1:
from langchain_databricks import ChatDatabricks
llm_llama = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct", max_tokens=500)
for chunk in llm_llama.stream("Who is Brad Pitt?"):
    print(chunk.content, end="", flush=True)
Retrievers
Retrievers fetch relevant documents or data using similarity search. Example:
from langchain_community.retrievers import WikipediaRetriever
retriever = WikipediaRetriever()
docs = retriever.invoke(input="Brad Pitt")
print(docs[0])
Tools
Tools are functions that execute specific tasks in a chain. Example:
from langchain_community.tools import YouTubeSearchTool
tool = YouTubeSearchTool()
video_links = tool.run("Brad Pitt movie trailer")
print(tool.description)
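In recent versions, YouTubeSearchTool.run returns its results as a single string that looks like a Python list of URLs (the exact format may vary by version). If you need the individual links, the string can be parsed with ast.literal_eval. The sample string below is illustrative, not real tool output:

```python
import ast

# Illustrative sample of the string the tool can return; actual URLs will differ
video_links = "['https://www.youtube.com/watch?v=abc123', 'https://www.youtube.com/watch?v=def456']"

links = ast.literal_eval(video_links)  # safely parse the list literal
for url in links:
    print(url)
```

Using ast.literal_eval (rather than eval) keeps the parse restricted to Python literals, which is safer for strings coming from an external tool.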
Building the Multi-stage Chain
Task 1: Create a Vector Store
The dataset used in this demo is preprocessed and stored as Delta tables. Run the classroom setup script to prepare the vector store index.
vs_endpoint_prefix = "vs_endpoint_"
vs_endpoint_name = vs_endpoint_prefix + str(get_fixed_integer(DA.unique_name("_")))
print(f"Assigned Vector Search endpoint name: {vs_endpoint_name}.")
source_table_fullname = f"{DA.catalog_name}.{DA.schema_name}.dais_text"
vs_index_table_fullname = f"{DA.catalog_name}.{DA.schema_name}.dais_embeddings"
create_vs_index(vs_endpoint_name, vs_index_table_fullname, source_table_fullname, "Title")
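Conceptually, the vector search index ranks stored embeddings by their similarity to a query embedding. The toy sketch below illustrates that ranking with hand-made three-dimensional vectors (purely illustrative values, not real embeddings):

```python
import math

def cosine(a, b):
    # Cosine similarity: dot product divided by the product of vector norms
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Toy "embeddings" for three titles (illustrative numbers only)
index = {
    "Intro to Delta Lake": [0.9, 0.1, 0.0],
    "Scaling Spark Jobs": [0.1, 0.9, 0.2],
    "Cooking with Python": [0.0, 0.2, 0.9],
}
query_vec = [0.85, 0.15, 0.05]  # stands in for the embedded user query

# Return the stored title whose vector is closest to the query vector
best = max(index, key=lambda title: cosine(index[title], query_vec))
print(best)
```

A real vector search endpoint does the same ranking at scale with approximate nearest-neighbor indexing instead of a brute-force scan.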
Task 2: Build the First Chain
The first chain retrieves relevant video titles based on user queries.
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

prompt_template_1 = PromptTemplate.from_template(
    """Construct a search query for YouTube based on the titles below. Include "DAIS 2023" in the query.

Titles: {context}"""
)
# The retriever fills {context}; the parser returns the model's reply as a plain string
retriever_chain = {"context": retriever} | prompt_template_1 | llm_llama | StrOutputParser()
Task 3: Build the Second Chain
The second chain fetches video links using the YouTube search tool.
from langchain_community.tools import YouTubeSearchTool
from langchain_core.runnables import RunnableLambda
def get_videos(input):
    tool_yt = YouTubeSearchTool()
    return tool_yt.run(input)

chain_youtube = RunnableLambda(get_videos)
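RunnableLambda simply wraps a plain function so it can be composed with | like any other chain step. The pure-Python sketch below mimics that piping idea (the names are illustrative, not LangChain's internals):

```python
class Step:
    """Toy RunnableLambda: wraps a function so steps compose with |."""
    def __init__(self, fn):
        self.fn = fn

    def __or__(self, other):
        # Piping two steps yields a new step that runs them in sequence
        return Step(lambda x: other.fn(self.fn(x)))

    def invoke(self, x):
        return self.fn(x)

# Two toy stages: build a search query, then "search" with it
build_query = Step(lambda titles: f'{titles} "DAIS 2023"')
search = Step(lambda q: [f"video result for: {q}"])

chain = build_query | search
print(chain.invoke("LLMOps talk"))
```

This is why get_videos can slot straight into the pipeline: once wrapped, it accepts the previous stage's output and passes its return value downstream.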
Task 4: Build the Final Chain
The third chain recommends videos based on user queries.
prompt_template_3 = PromptTemplate.from_template(
"""You are a Databricks expert. Answer the question below and recommend related videos:
Question: {input}
Recommended Videos:
{videos}"""
)
from langchain_core.runnables import RunnablePassthrough

multi_chain = (
    {"input": RunnablePassthrough(), "videos": retriever_chain | chain_youtube}
    | prompt_template_3
    | llm_llama
)
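The dict at the start of multi_chain runs its values in parallel on the same input and feeds the resulting mapping into the prompt. Conceptually it behaves like this pure-Python sketch (illustrative names, not LangChain internals):

```python
# Toy version of the {"input": ..., "videos": ...} parallel step:
# each value is a function applied to the same input, and the results
# are gathered into a dict that fills the prompt's variables.
def passthrough(x):
    return x

def fake_videos(x):
    # Stands in for retriever_chain | chain_youtube
    return f"videos about {x}"

def parallel(mapping, x):
    return {key: fn(x) for key, fn in mapping.items()}

filled = parallel({"input": passthrough, "videos": fake_videos}, "Delta Lake")
prompt = "Question: {input}\nRecommended Videos:\n{videos}".format(**filled)
print(prompt)
```

In LCEL the dict is coerced to a RunnableParallel, so both branches receive the original user query before the merged result reaches prompt_template_3.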
Task 5: Save the Chain to Unity Catalog
Finally, save the constructed chain to the Unity Catalog for future use. Storing the chain in Unity Catalog facilitates collaboration, versioning, and seamless reuse across projects:
from mlflow.models import infer_signature
import mlflow

mlflow.set_registry_uri("databricks-uc")
model_name = f"{DA.catalog_name}.{DA.schema_name}.multi_stage_demo"
# Run the chain once to capture an example input/output pair for the signature
query = "What is Databricks Vector Search?"
response = multi_chain.invoke(query)
with mlflow.start_run(run_name="multi_stage_demo") as run:
    signature = infer_signature(query, response)
    mlflow.langchain.log_model(multi_chain, artifact_path="model", signature=signature, registered_model_name=model_name)
Conclusion
In this demo, you’ve learned how to build a multi-stage reasoning system using Databricks and LangChain. The system combines multiple components, such as LLMs, retrievers, and tools, to create a sophisticated AI pipeline capable of complex tasks like video recommendations. By integrating these chains, you can create highly customizable AI workflows tailored to specific use cases.