02 Gen AI on Databricks: Building a Multi-stage Reasoning Chain

THE BRICK LEARNING

--

In this demo, we will explore how to build a multi-stage reasoning system using Databricks features and LangChain. The demo begins by introducing components commonly used in multi-stage systems and progresses to building a complete chain for advanced reasoning tasks.

The primary goal is to show how to create chains that handle specific tasks, such as answering user queries, searching external resources, and recommending videos.

Learning Objectives

By the end of this demo, you will:

  1. Recognize that LangChain chains can include stages that are not LLM-based, such as retrieving external data or using specialized tools.
  2. Create basic LLM chains to connect prompts and LLMs.
  3. Use tools to complete various tasks in a multi-stage system.
  4. Construct sequential chains of multiple LLMChains to perform multi-stage reasoning analysis.

Requirements

To execute this demo, use the following Databricks runtime:

15.4.x-cpu-ml-scala2.12

This runtime ensures compatibility with the required libraries and provides the necessary features for handling advanced AI workflows.

Additionally, ensure the following libraries are installed:

%pip install -U -qq databricks-sdk databricks-vectorsearch langchain-databricks langchain==0.3.7 langchain-community==0.3.7 youtube_search Wikipedia grandalf

Restart the Python environment after installation:

dbutils.library.restartPython()

Then run the classroom setup script, which defines the configuration variables used throughout this lesson:

%run ../Includes/Classroom-Setup-02

Setting Up

Throughout this demo, the object DA provided by Databricks Academy will be used. It includes variables such as your username, catalog name, schema name, working directory, and dataset locations. Run the code below to verify your setup:

print(f"Username:          {DA.username}")
print(f"Catalog Name:      {DA.catalog_name}")
print(f"Schema Name:       {DA.schema_name}")
print(f"Working Directory: {DA.paths.working_dir}")
print(f"Dataset Location:  {DA.paths.datasets}")

Using LLMs and Prompts Without External Libraries

Before diving into LangChain, let’s demonstrate a simple interaction with a foundation model using the databricks-sdk:

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole

w = WorkspaceClient()

genre = "romance"
actor = "Brad Pitt"
prompt = f"Tell me about a {genre} movie which {actor} is one of the actors."

# Query the model serving endpoint directly with the SDK
response = w.serving_endpoints.query(
    name="databricks-meta-llama-3-1-70b-instruct",
    messages=[ChatMessage(role=ChatMessageRole.USER, content=prompt)],
)
print(response.choices[0].message.content)

While this approach works for simple queries, using LangChain can streamline the construction of complex, multi-stage systems by providing modular components and reusable interfaces.

LangChain Basics

LangChain provides essential components for building multi-stage reasoning systems, including prompts, LLMs, retrievers, and tools. Each of these components plays a unique role: prompts define input queries, LLMs perform reasoning, retrievers fetch relevant data, and tools execute specific tasks.
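Conceptually, a chain is just function composition: each stage transforms the output of the previous one. A minimal plain-Python sketch of that idea (the `fake_llm` stand-in is illustrative, not a real model call):

```python
def compose(*stages):
    """Run each stage on the previous stage's output, left to right."""
    def chain(x):
        for stage in stages:
            x = stage(x)
        return x
    return chain

build_prompt = lambda question: f"Answer concisely: {question}"  # "prompt" stage
fake_llm = lambda prompt: prompt.upper()                         # stand-in "LLM" stage

chain = compose(build_prompt, fake_llm)
```

LangChain's pipe operator (`|`) expresses exactly this composition, with extra conveniences such as streaming and batching.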

Prompts

Prompts define the structure of your queries to the model. For example:

from langchain.prompts import PromptTemplate

prompt_template = PromptTemplate.from_template("Tell me about a {genre} movie which {actor} is one of the actors.")
formatted_prompt = prompt_template.format(genre="romance", actor="Brad Pitt")
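Under the hood, a prompt template is essentially named-variable substitution into a fixed text skeleton; plain Python `str.format` captures the same idea:

```python
# Same substitution as the PromptTemplate above, without any library
template = "Tell me about a {genre} movie which {actor} is one of the actors."
formatted = template.format(genre="romance", actor="Brad Pitt")
```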

LLMs

LLMs serve as the reasoning core of your system. Here’s an example using Meta’s Llama 3.1:

from langchain_databricks import ChatDatabricks

llm_llama = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct", max_tokens=500)
for chunk in llm_llama.stream("Who is Brad Pitt?"):
    print(chunk.content, end="", flush=True)  # print tokens as they stream in

Retrievers

Retrievers fetch relevant documents or data using similarity search. Example:

from langchain_community.retrievers import WikipediaRetriever

retriever = WikipediaRetriever()
docs = retriever.invoke(input="Brad Pitt")
print(docs[0])
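To build intuition for what a retriever does, here is a deliberately simplified sketch that ranks documents by word overlap with the query; real retrievers use vector similarity over embeddings, but the shape of the operation is the same: query in, top-k documents out.

```python
def retrieve(query, documents, k=1):
    """Rank documents by word overlap with the query and return the top k."""
    query_words = set(query.lower().split())
    ranked = sorted(
        documents,
        key=lambda d: len(query_words & set(d.lower().split())),
        reverse=True,
    )
    return ranked[:k]

docs = [
    "Brad Pitt starred in many acclaimed films",
    "Databricks runs Spark in the cloud",
]
```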

Tools

Tools are functions that execute specific tasks in a chain. Example:

from langchain_community.tools import YouTubeSearchTool

tool = YouTubeSearchTool()
video_links = tool.run("Brad Pitt movie trailer")
print(tool.description)
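A tool boils down to a named callable plus a description that a chain or agent can inspect when deciding what to invoke. A minimal illustrative sketch (`SimpleTool` and `youtube_stub` are hypothetical names for this demo, not LangChain classes):

```python
class SimpleTool:
    """A tool: a named callable plus a human/LLM-readable description."""
    def __init__(self, name, description, fn):
        self.name = name
        self.description = description
        self.fn = fn

    def run(self, query):
        return self.fn(query)

youtube_stub = SimpleTool(
    name="youtube_search",
    description="Returns video links matching a search query.",
    fn=lambda q: [f"video result for: {q}"],  # stand-in for a real search
)
```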

Building the Multi-stage Chain

Task 1: Create a Vector Store

The dataset used in this demo is preprocessed and stored as Delta tables. Run the classroom setup script to prepare the vector store index.

vs_endpoint_prefix = "vs_endpoint_"
vs_endpoint_name = vs_endpoint_prefix + str(get_fixed_integer(DA.unique_name("_")))
print(f"Assigned Vector Search endpoint name: {vs_endpoint_name}.")

source_table_fullname = f"{DA.catalog_name}.{DA.schema_name}.dais_text"
vs_index_table_fullname = f"{DA.catalog_name}.{DA.schema_name}.dais_embeddings"

create_vs_index(vs_endpoint_name, vs_index_table_fullname, source_table_fullname, "Title")

Task 2: Build the First Chain

The first chain retrieves relevant video titles based on user queries.

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

prompt_template_1 = PromptTemplate.from_template(
    """Construct a search query for YouTube based on the titles below. Include "DAIS 2023" in the query.

Titles: {titles}"""
)

def format_docs(docs):
    # Join the retrieved documents' contents into a single string for the prompt
    return "\n".join(doc.page_content for doc in docs)

# `retriever` is the retriever over the vector search index prepared in Task 1
retriever_chain = (
    {"titles": retriever | format_docs}
    | prompt_template_1
    | llm_llama
    | StrOutputParser()
)

Task 3: Build the Second Chain

The second chain fetches video links using the YouTube search tool.

from langchain_community.tools import YouTubeSearchTool
from langchain_core.runnables import RunnableLambda

def get_videos(input):
    tool_yt = YouTubeSearchTool()
    return tool_yt.run(input)

chain_youtube = RunnableLambda(get_videos)

Task 4: Build the Final Chain

The third chain recommends videos based on user queries.

from langchain_core.runnables import RunnablePassthrough

prompt_template_3 = PromptTemplate.from_template(
    """You are a Databricks expert. Answer the question below and recommend related videos:

Question: {input}

Recommended Videos:
{videos}"""
)

multi_chain = (
    {"input": RunnablePassthrough(), "videos": retriever_chain | chain_youtube}
    | prompt_template_3
    | llm_llama
)
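The dict at the head of the chain fans the same input out to several branches and collects the results by key, which then fill the prompt's variables. A plain-Python sketch of that pattern (the lambdas are stand-ins for the real passthrough and retriever-to-YouTube branches):

```python
def fan_out(user_input, branches):
    """Feed the same input to every branch and collect results by name."""
    return {name: branch(user_input) for name, branch in branches.items()}

result = fan_out("vector search talks", {
    "input": lambda q: q,                     # passthrough branch
    "videos": lambda q: [f"video for: {q}"],  # stand-in for the video branch
})
```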

Task 5: Save the Chain to Unity Catalog

Finally, save the constructed chain to the Unity Catalog for future use. Storing the chain in Unity Catalog facilitates collaboration, versioning, and seamless reuse across projects:

from mlflow.models import infer_signature
import mlflow

mlflow.set_registry_uri("databricks-uc")
model_name = f"{DA.catalog_name}.{DA.schema_name}.multi_stage_demo"

# Example query/response pair used only to infer the model signature
query = "What is Databricks Vector Search?"
response = multi_chain.invoke(query)

with mlflow.start_run(run_name="multi_stage_demo") as run:
    signature = infer_signature(query, response.content)
    mlflow.langchain.log_model(
        multi_chain,
        artifact_path="model",
        signature=signature,
        registered_model_name=model_name,
    )

Conclusion

In this demo, you’ve learned how to build a multi-stage reasoning system using Databricks and LangChain. The system combines multiple components, such as LLMs, retrievers, and tools, to create a sophisticated AI pipeline capable of complex tasks like video recommendations. By integrating these chains, you can create highly customizable AI workflows tailored to specific use cases.
