Augmenting Gemini-1.0-Pro with Knowledge Graphs via LangChain

Rubens Zimbres
Google Cloud - Community
13 min read · Mar 4, 2024

The concept of Knowledge Graphs is grounded in machine learning and artificial intelligence, and aims to represent and leverage structured knowledge in a graph-based format. These graphs consist of nodes and edges, which represent concepts and the relationships between them, respectively. In ML, each node can be represented as one or more embeddings.

In this article, I will explore how to create a Knowledge Graph from Wikipedia articles: creating relationships (edges) among them (nodes), assigning weights to these relationships, storing them in a structured manner, and using this data with LangChain to create a chatbot with memory. I used Wikipedia for the sake of simplicity; you can use any source you want, for example a database of laws for question answering, customer service records, or anything else.

The whole concept is similar to RAG (Retrieval Augmented Generation), but in my experiments hallucinations decreased more with this approach.

The underlying theory of Knowledge Graphs (KG) involves the following key aspects:

Graph Representation:

  • Knowledge Graphs are structured as graphs consisting of nodes (entities or knowledge concepts), edges (relationships between concepts) and edge weights (which represent the degree of relevance of each relationship).
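As a minimal illustration of this structure, here is a directed, weighted graph stored as a plain dictionary (standard library only; the tutorial itself uses networkx later). The concept names and weights are invented for the example.

```python
from collections import defaultdict

# Directed, weighted graph: graph[source][target] = edge weight (relevance)
graph = defaultdict(dict)

def add_edge(source: str, target: str, weight: float) -> None:
    """Add (or overwrite) a weighted, directed relationship."""
    graph[source][target] = weight

# Illustrative concepts and weights (not taken from real data)
add_edge("data science", "statistics", 0.9)
add_edge("data science", "machine learning", 0.7)
add_edge("machine learning", "statistics", 0.4)

# Nodes are every source plus every target; edges are (source, target, weight)
nodes = set(graph) | {t for targets in graph.values() for t in targets}
edges = [(s, t, w) for s, targets in graph.items() for t, w in targets.items()]

print(sorted(nodes))
print(edges)
```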

Semantic Relationships:

  • The relationships between entities carry semantic meaning, providing context to the data.
  • By incorporating semantics, Knowledge Graphs enable a more nuanced understanding of, and reasoning about, the relationships between different entities. Encoded as embeddings, these semantic relationships become valuable for LLMs to interpret.
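To make the embedding idea concrete: semantic closeness between two nodes is commonly measured with the cosine similarity of their embedding vectors. The sketch below assumes tiny, hand-made 3-dimensional vectors; real embedding models output hundreds of dimensions, and these values carry no real meaning.

```python
import math

# Hypothetical 3-dimensional node embeddings (hand-made for illustration)
embeddings = {
    "data science": [0.9, 0.1, 0.3],
    "statistics": [0.8, 0.2, 0.4],
    "poetry": [0.1, 0.9, 0.0],
}

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

sim_related = cosine_similarity(embeddings["data science"], embeddings["statistics"])
sim_unrelated = cosine_similarity(embeddings["data science"], embeddings["poetry"])
print(round(sim_related, 3), round(sim_unrelated, 3))
```

Related concepts end up with a higher score than unrelated ones, which is what lets embeddings carry the semantics of the graph.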

Linking Information:

  • Knowledge Graphs link diverse information sources and domains, creating a unified knowledge base.
  • Integration of information from various domains allows for a holistic representation of knowledge, facilitating comprehensive analysis.

Machine Learning Applications:

  • Machine learning algorithms can exploit the rich structure of Knowledge Graphs to make predictions, perform inference, and enhance decision-making processes. Here, KGs will be used as context for an LLM (Gemini-1.0-Pro) to answer questions.

Entity and Relationship Types:

  • The theory involves defining and categorizing different types of entities and relationships within the graph.
  • This categorization helps in organizing and structuring the knowledge, making it more accessible for both humans and machine learning models.
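One possible way to encode typed entities and relationships, sketched with dataclasses. The Entity and Relationship classes, the kind labels and the predicate names are hypothetical; the tutorial below does not define a schema and derives relationships from Wikipedia links instead.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Entity:
    name: str
    kind: str  # entity type, e.g. "Field" or "Person" (illustrative labels)

@dataclass(frozen=True)
class Relationship:
    subject: Entity
    predicate: str  # relationship type, e.g. "is_subfield_of"
    obj: Entity
    weight: float = 1.0

ml = Entity("machine learning", "Field")
ai = Entity("artificial intelligence", "Field")
asimov = Entity("Isaac Asimov", "Person")

facts = [
    Relationship(ml, "is_subfield_of", ai, 0.9),
    Relationship(asimov, "wrote_about", ai, 0.6),
]

# Typed entities make simple structured queries possible
people = [r.subject.name for r in facts if r.subject.kind == "Person"]
print(people)
```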

Scalability and Interoperability:

  • Knowledge Graphs are designed to be scalable, accommodating the inclusion of vast amounts of information.
  • Interoperability with existing data sources and systems is a crucial aspect, ensuring that Knowledge Graphs can be seamlessly integrated into various applications and environments.

For more information, see:

Nickel, M., Murphy, K., Tresp, V., & Gabrilovich, E. (2016). A review of relational machine learning for knowledge graphs. Proceedings of the IEEE, 104(1), 11–33.

For one of the most basic concepts of Graph Theory, see my article Graph Neural Networks: the message passing algorithm.

Source: https://ieeexplore.ieee.org/document/7358050

This tutorial reuses libraries and concepts from my other articles, Cost-Efficient Multi-Agent Collaboration with LangGraph + Gemma for Code Generation and Code Generation using Retrieval Augmented Generation + LangChain.

Note that in some places I use Python slicing, such as [:12], to limit the amount of data processed and make computation faster.

Let’s start coding. We will need to install the necessary libraries:

pip install -U langchain langchain_openai langsmith pandas langchain_experimental matplotlib
pip install --upgrade --quiet langchain langsmith langchainhub
pip install -q tiktoken==0.5.2
pip install wikipedia
pip install networkx

Import them and define the LangChain API key, as well as the environment variables for LangSmith, the tracing dashboard:

import pandas as pd
import random
import wikipedia as wp
from wikipedia.exceptions import DisambiguationError, PageError
import networkx as nx
import matplotlib.pyplot as plt
from langsmith import Client
from langchain_core.tracers.context import tracing_v2_enabled
import os

os.environ["LANGCHAIN_API_KEY"]="your-api-key"

# Add tracing in LangSmith
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "KG Project"

client = Client()

Now let’s see how the wikipedia library behaves. Let’s get a summary of the Data Science concept:

print(wp.summary("data science"))

Data science is an interdisciplinary academic field that uses statistics, scientific computing, scientific methods, processes, algorithms and systems to extract or extrapolate knowledge and insights from potentially noisy, structured, or unstructured data.Data science also integrates domain knowledge from the underlying application domain (e.g., natural sciences, information technology, and medicine….

We are also able to get all the links in the Data Science page:

wp.page("data science").links[:12] ## Brackets limit to speed up

Great. To build the Knowledge Graph, we’ll create a class that stores all the information learned and has methods for scanning, ranking and summarising.

This class, RelationshipGenerator, gets the Wikipedia pages, scans them for content (knowledge concepts, i.e. nodes), computes the weights of the relationships (edges) and builds a dataframe. This output can easily be stored in a database.
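The link weights come from a simple heuristic, implemented in the class’s weight_link method: start at 0.1, add one for every occurrence of the link text in the page body, add 3 if the link text appears in the page summary, and then divide every weight by the largest one. Below is a self-contained sketch of that heuristic; SimpleNamespace stands in for a wikipedia page object, and the page text is invented.

```python
from types import SimpleNamespace

def weight_link(page, link):
    """Raw weight of an outgoing link, following the same rules as weight_link."""
    weight = 0.1
    weight += page.content.lower().count(link.lower())  # mentions in the body
    if link.lower() in page.summary.lower():            # bonus if in the summary
        weight += 3
    return weight

# Stand-in for a wikipedia page object (only .content and .summary are used)
page = SimpleNamespace(
    summary="Data science uses statistics.",
    content="Data science uses statistics. Statistics and machine learning overlap.",
)

links = ["Statistics", "Machine learning", "Biology"]
raw = [weight_link(page, l) for l in links]
# Normalise so the strongest link on the page has weight 1
normalised = [w / max(raw) for w in raw]
print(raw, normalised)
```

"Statistics" scores highest because it is mentioned twice and appears in the summary; a link never mentioned keeps only the 0.1 baseline.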


class RelationshipGenerator():
    """Generates relationships between terms, based on wikipedia links"""

    def __init__(self):
        """Links are directional, start + end, they should also have a weight"""
        self.links = []  # [start, end, weight]
        self.failed = set()  # terms whose page could not be fetched

    def scan(self, start=None, repeat=0):
        """Start scanning from a specific word, or from internal database

        Args:
            start (str): the term to start searching from, can be None to let
                algorithm decide where to start
            repeat (int): the number of times to repeat the scan, each time
                from an automatically chosen starting point
        """
        if start in [l[0] for l in self.links]:
            raise Exception("Already scanned")

        term_search = start is not None

        # If a start isn't defined, we should find one
        if start is None:
            start = self.find_starting_point()
            print(start)

        try:
            # Scan the starting point specified for links
            print(f"Scanning page {start}...")
            # Fetch the page through the Wikipedia API
            page = wp.page(start)
            links = list(set(page.links))
            # ignore some uninteresting terms
            links = [l for l in links if not self.ignore_term(l)]

            # Weight each link, then normalize the weights to [0, 1]
            link_weights = [self.weight_link(page, link) for link in links]
            max_weight = max(link_weights, default=1)
            link_weights = [w / max_weight for w in link_weights]

            # Nodes known before this scan, to report how many are new
            known_nodes = set(l[1] for l in self.links)

            # Add links to database; explicitly requested terms get a +2 bonus
            for link, weight in zip(links, link_weights):
                self.links.append([start, link.lower(), weight + 2 * int(term_search)])  # 3 works pretty well

            # Print some data to the user on progress
            explored_nodes_count = len(set(l[0] for l in self.links))
            total_nodes_count = len(set(l[1] for l in self.links))
            new_nodes_count = len(set(l.lower() for l in links) - known_nodes)
            print(f"New nodes added: {new_nodes_count}, Total Nodes: {total_nodes_count}, Explored Nodes: {explored_nodes_count}")
        except (DisambiguationError, PageError):
            # This happens if the page has disambiguation or doesn't exist.
            # Remember the term so it is not chosen as a starting point again.
            self.failed.add(start)

        # Repeat the scan from the next best unexplored term
        if repeat > 0:
            self.scan(None, repeat - 1)

    def get_pages(self, start):
        """Return a DataFrame with the link, normalized weight and content
        of the pages linked from `start`"""
        # Scan the starting point specified for links
        print(f"Scanning page {start}...")
        # Fetch the page through the Wikipedia API
        page = wp.page(start)
        links = list(set(page.links))[0:20]  ## Page links limited here
        # ignore some uninteresting terms
        links = [l for l in links if not self.ignore_term(l)]

        # Add links, weights and pages to database. A row is kept only when
        # the linked page could be fetched, so the three columns stay aligned.
        rows = []
        for link in links:
            try:
                content = wp.page(link).content
            except Exception:
                # Disambiguation pages, missing pages, network errors...
                continue
            rows.append({"link": link,
                         "link_weights": self.weight_link(page, link),
                         "pages": content})
            print(content[:20])

        df_ = pd.DataFrame(rows, columns=["link", "link_weights", "pages"])

        # Normalize link weights
        df_['link_weights'] = df_['link_weights'] / df_['link_weights'].max()

        return df_

    def find_starting_point(self):
        """Find the best place to start when no input is given"""
        # Need some links to work with.
        if len(self.links) == 0:
            raise Exception("Unable to start, no start defined or existing links")

        # Get top terms
        res = self.rank_terms()
        sorted_links = list(zip(res.index, res.values))
        # Terms already scanned, plus terms whose page could not be fetched
        all_starts = set(l[0] for l in self.links) | self.failed

        # Iterate over the top links until we find a new one, skipping
        # identifiers (these are on many Wikipedia pages)
        for term, _ in sorted_links:
            if term not in all_starts and len(term) > 0 and '(identifier)' not in term:
                return term

        # no link found
        raise Exception("No starting point found within links")

    @staticmethod
    def weight_link(page, link):
        """Weight an outgoing link for a given source page

        Args:
            page (obj): the wikipedia page object the link comes from
            link (str): the outgoing link of interest

        Returns:
            (float): the raw weight (at least 0.1); callers normalize it
                by dividing by the largest weight on the page
        """
        weight = 0.1

        # One point per mention of the link text in the page body
        link_counts = page.content.lower().count(link.lower())
        weight += link_counts

        # Bonus if the link text appears in the page summary
        if link.lower() in page.summary.lower():
            weight += 3

        return weight

    def get_database(self):
        return sorted(self.links, key=lambda x: -x[2])

    def rank_terms(self, with_start=False):
        # We can use graph theory here!
        # tws = [l[1:] for l in self.links]

        df = pd.DataFrame(self.links, columns=["start", "end", "weight"])

        if with_start:
            # DataFrame.append was removed in pandas 2.0; pd.concat replaces it
            df = pd.concat([df, df.rename(columns={"end": "start", "start": "end"})])

        return df.groupby("end").weight.sum().sort_values(ascending=False)

    def get_key_terms(self, n=20):
        return "'" + "', '".join([t for t in self.rank_terms().head(n).index.tolist() if "(identifier)" not in t]) + "'"

    @staticmethod
    def ignore_term(term):
        """List of terms to ignore"""
        if "(identifier)" in term or term == "doi":
            return True
        return False

In the code above, @staticmethod is a decorator that is used to define a static method within a class. This way, a method belongs directly to a class rather than an instance of the class. You can call the method directly on the class without creating an instance.
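A small illustration of the difference, using a hypothetical TermFilter class whose static method mirrors ignore_term:

```python
class TermFilter:
    """Hypothetical helper mirroring RelationshipGenerator.ignore_term."""

    @staticmethod
    def ignore_term(term):
        # No `self` parameter: the result depends only on the argument
        return "(identifier)" in term or term == "doi"

# Called on the class itself, no instance needed...
print(TermFilter.ignore_term("ISBN (identifier)"))
# ...but calling it through an instance also works
print(TermFilter().ignore_term("statistics"))
```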

We’ll also define a function to simplify a graph that has many nodes. This is useful for plotting: if we add everything that was scanned, the plot will be unreadable and we will be unable to analyze it. You can customize which nodes and links to keep, to look for specific insights.

def simplify_graph(rg, max_nodes=1000):
    # Get most interesting terms.
    nodes = rg.rank_terms()

    # Get nodes to keep. Note that max_nodes acts as a fraction, not a count:
    # the top int(max_nodes * N / 5) of the N ranked terms are kept
    keep_nodes = set(nodes.head(int(max_nodes * len(nodes) / 5)).index.tolist())

    # Filter list of nodes so that there are no nodes outside those of interest
    filtered_links = list(filter(lambda x: x[1] in keep_nodes, rg.links))
    filtered_links = list(filter(lambda x: x[0] in keep_nodes, filtered_links))

    # Define a new object and set its links
    ac = RelationshipGenerator()
    ac.links = filtered_links

    return ac
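Because of the int(max_nodes * N / 5) expression, max_nodes=1 keeps the top 20% of ranked terms, and any value of 5 or more keeps all of them; a link then survives only if both its start and end terms are kept. The standard-library sketch below reproduces that filtering on a toy link list (terms and weights invented), where even max_nodes=5 drops the links starting at "data science", because that term never appears as a link target and so is never ranked.

```python
from collections import defaultdict

# Toy [start, end, weight] links in the same format as RelationshipGenerator.links
links = [
    ["data science", "statistics", 1.0],
    ["data science", "machine learning", 0.5],
    ["statistics", "machine learning", 0.8],
    ["machine learning", "statistics", 0.6],
    ["statistics", "probability", 0.1],
]

def rank_terms(links):
    """Sum incoming weights per target term, highest first (like rank_terms)."""
    scores = defaultdict(float)
    for _, end, weight in links:
        scores[end] += weight
    return sorted(scores, key=scores.get, reverse=True)

def simplify(links, max_nodes):
    ranked = rank_terms(links)
    keep = set(ranked[: int(max_nodes * len(ranked) / 5)])
    # Keep a link only if both of its endpoints are among the kept terms
    return [l for l in links if l[0] in keep and l[1] in keep]

print(simplify(links, 5))
print(simplify(links, 4))
```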

We now build the Knowledge Graph:

rg = RelationshipGenerator()
rg.scan("data science")
rg.scan("data analysis")
rg.scan("artificial intelligence")
rg.scan("machine learning")

...and get the content of the pages to build the dataframe:


result1=rg.get_pages("data science")
result2=rg.get_pages("data analysis")
result3=rg.get_pages("artificial intelligence")
result=pd.concat([result1,result2,result3]).dropna()
result

Let’s take a look at part of the page content:

result.iloc[0,2]

We repeat the scan, to get a deeper knowledge of the concepts:

rg.scan(repeat=10)

We then rank the terms:

rg.rank_terms()
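rank_terms scores each term by the sum of the weights of the links pointing to it, and find_starting_point, used when no start term is given (start=None), picks the highest-scoring term that has not been scanned yet. A standard-library sketch of that selection on invented links:

```python
from collections import defaultdict

# Toy [start, end, weight] links; the starts are terms that were already scanned
links = [
    ["data science", "statistics", 3.0],
    ["data science", "machine learning", 2.5],
    ["data analysis", "statistics", 2.8],
    ["data analysis", "data science", 2.2],
]

def rank_terms(links):
    """Total incoming weight per term, highest first."""
    scores = defaultdict(float)
    for _, end, weight in links:
        scores[end] += weight
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

def find_starting_point(links):
    """Highest-ranked term that has not been used as a start yet."""
    scanned = {start for start, _, _ in links}
    for term, _ in rank_terms(links):
        if term not in scanned and "(identifier)" not in term:
            return term
    raise ValueError("No starting point found within links")

print(rank_terms(links))
print(find_starting_point(links))
```

Terms linked from several scanned pages accumulate weight, so the next scan naturally moves towards the concepts shared by what has been explored so far.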

Now, let’s visualise the Knowledge Graph. This part is completely customizable, so you can work on the details to make it better. Here I used a random layout, but you could instead position the nodes so that their Euclidean distance in the 2D plane reflects their semantic similarity. This would group related concepts more sensibly.

def remove_self_references(l):  ## node connections to itself
    return [i for i in l if i[0] != i[1]]

def add_focus_point(links, focus="on me", focus_factor=3):
    for i, link in enumerate(links):
        if not (focus in link[0] or focus in link[1]):
            links[i] = [link[0], link[1], link[2] / focus_factor]
        else:
            links[i] = [link[0], link[1], link[2] * focus_factor]

    return links

def create_graph(rg, focus=None):

    links = rg.links
    links = remove_self_references(links)
    if focus is not None:
        links = add_focus_point(links, focus)

    node_data = rg.rank_terms()
    nodes = node_data.index.tolist()
    node_weights = node_data.values.tolist()
    node_weights = [nw * 100 for nw in node_weights]
    nodelist = nodes

    G = nx.DiGraph()  # MultiGraph()

    # G.add_node()
    G.add_nodes_from(nodes)

    # Add edges
    G.add_weighted_edges_from(links)

    pos = nx.random_layout(G, seed=17)  # positions for all nodes - seed for reproducibility

    fig = plt.figure(figsize=(12, 12))

    nx.draw_networkx_nodes(
        G, pos,
        nodelist=nodelist,
        node_size=node_weights,
        node_color='lightgreen',
        alpha=0.7
    )

    widths = nx.get_edge_attributes(G, 'weight')
    nx.draw_networkx_edges(
        G, pos,
        edgelist=list(widths.keys()),
        width=list(widths.values()),
        edge_color='lightgray',
        alpha=0.6
    )

    nx.draw_networkx_labels(G, pos=pos,
                            labels=dict(zip(nodelist, nodelist)), font_size=8,
                            font_color='black')

    plt.show()

ng = simplify_graph(rg, 5)

create_graph(ng)
Knowledge Graph created by the code provided in this tutorial via networkx

Notice in the plot that data science and data analysis (basic concepts that we scraped with the wikipedia library) act as the main clusters of information, from which connections radiate.

Now that we have the Knowledge Graph, the concepts and relationships, and our dataframe, let’s use LangChain to build a useful, grounded chatbot. Here I will use Google’s Gemini-1.0-Pro as the LLM:

from langchain.memory import ConversationKGMemory
from langchain.chains import ConversationChain
from langchain.prompts.prompt import PromptTemplate
from langchain.llms import VertexAI

Let’s define the LLM and the chain memory. After that, we add the context from the dataframe to memory as input/output pairs: the input is a simple prompt (“Tell me about” followed by the concept), and the output is the weight of the concept relationship concatenated with the related Wikipedia concept page.
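Each row of the dataframe becomes one input/output pair for memory.save_context. A standard-library sketch of that mapping, where to_memory_pairs is a hypothetical helper and the three lists stand in for the link, link_weights and pages columns (values invented):

```python
# Toy values standing in for the `result` dataframe columns
links = ["Statistics", "Machine learning"]
link_weights = [1.0, 0.42]
pages = ["Statistics is the study of data.",
         "Machine learning studies algorithms that learn from data."]

def to_memory_pairs(links, link_weights, pages):
    """Build the (input, output) dicts passed to memory.save_context."""
    pairs = []
    for link, weight, page in zip(links, link_weights, pages):
        inputs = {"input": "Tell me about {}".format(link)}
        outputs = {"output": "Weight is {}. {}".format(weight, page)}
        pairs.append((inputs, outputs))
    return pairs

for inputs, outputs in to_memory_pairs(links, link_weights, pages):
    print(inputs, outputs)
```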

After that, we extract the entities present in each concept page, and the knowledge triplets from the concatenation of concept, relationship weight and concept page. This is an automated, semi-structured approach to building the KG. These code blocks take some time, so feel free to use multiprocessing to run them in parallel.
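Since these loops mostly wait on LLM API calls, a thread pool is a lighter alternative to multiprocessing. This is a sketch under that assumption: process_row is a hypothetical placeholder for one call such as memory.get_knowledge_triplets, and because it is not established here that the memory object can safely be updated from several threads, the sketch only runs independent calls in parallel and collects their results for the main thread.

```python
from concurrent.futures import ThreadPoolExecutor

def process_row(text):
    """Hypothetical stand-in for one slow, I/O-bound call (e.g. an LLM request).
    Here it simply counts words."""
    return len(text.split())

texts = ["data science uses statistics", "machine learning", "graphs"]

# executor.map runs the calls concurrently and returns results in input order
with ThreadPoolExecutor(max_workers=4) as executor:
    results = list(executor.map(process_row, texts))

print(results)
```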

llm = VertexAI(
    model_name="gemini-1.0-pro",
    max_output_tokens=256,
    temperature=0.1,
    verbose=False,
)

memory = ConversationKGMemory(llm=llm, return_messages=True)
## This takes some time .....

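for i in range(result.shape[0]):
    try:
        memory.save_context({"input": "Tell me about {}".format(result.link.iloc[i])}, {"output": "Weight is {}. {}".format(result.link_weights.iloc[i], result.pages.iloc[i])})
    except Exception as e:
        print(f"Skipping row {i}: {e}")
## This takes even more time .....

for i in range(result.shape[0]):
    try:
        # Entities mentioned in the concept page (returned, not stored)
        entities = memory.get_current_entities(result.pages.iloc[i])
        # Triplets extracted from concept + weight + page. The method returns
        # them instead of storing them, so we add them to the memory's graph
        triplets = memory.get_knowledge_triplets(
            str(result.link.iloc[i]) + " " + str(result.link_weights.iloc[i]) + " " + str(result.pages.iloc[i]))
        for triplet in triplets:
            memory.kg.add_triple(triplet)
    except Exception as e:
        print(f"Skipping row {i}: {e}")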

Now we define the LangChain template, the instructions for the LLM. The prompt to the LLM will contain the user’s question and also the history that is built up over the conversation. Note that in RAG (Retrieval Augmented Generation), we explicitly define {context} inside the prompt (see my other article Code Generation using Retrieval Augmented Generation + LangChain), while here the context is added via LangChain memory.
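In both approaches the model ultimately receives a single prompt string; the difference is where the filled-in text comes from. A minimal sketch with plain str.format, assuming a simplified version of the template and an invented history string (in the chain, PromptTemplate and the KG memory fill these two variables):

```python
# Simplified template; the chain fills {history} from memory and {input}
# from the user's question
template = """Relevant Information:

{history}

Conversation:
Human: {input}
AI:"""

# Hypothetical memory contents rendered as text
history = "On Isaac Asimov: Isaac Asimov wrote the Three Laws of Robotics."
question = "What are the Three Laws?"

prompt_text = template.format(history=history, input=question)
print(prompt_text)
```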

template = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides 
lots of specific details from its context. If the AI does not know the answer to a question, it will use
concepts stored in memory that have very similar weights.

Relevant Information:

{history}

Conversation:
Human: {input}
AI:"""

prompt = PromptTemplate(input_variables=["history", "input"], template=template)

conversation_with_kg = ConversationChain(
    llm=llm, verbose=False, prompt=prompt,
    memory=memory  # reuse the memory populated with the Wikipedia context above
)


Now that everything is set up, we can finally start the conversation:

with tracing_v2_enabled(project_name="KG Project"):  # Send to LangSmith

    # Question content inside the KG context
    question = "Hi, how Asimov contributed for artificial intelligence?"
    # Answer: call the LLM once and reuse the result
    answer = conversation_with_kg.predict(input=question)
    print(answer)
    # Add to history of conversations
    memory.save_context({"input": question}, {"output": answer})

Isaac Asimov was a prolific science fiction writer who wrote extensively about artificial intelligence. He is best known for his Three Laws of Robotics, which are a set of ethical guidelines for the design and use of robots. The Three Laws are:

1. A robot may not injure a human being or, through inaction, allow a human being to come to harm.
2. A robot must obey the orders given it by human beings, except where such orders would conflict with the First Law.
3. A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.

Asimov’s Three Laws of Robotics have been highly influential in the field of artificial intelligence, and they continue to be used as a basis for ethical discussions about the development and use of AI systems.

In addition to his Three Laws of Robotics, Asimov also wrote a number of other works of science fiction that explored the themes of artificial intelligence and robotics. These works include the Robot series, the Foundation series, and the Galactic Empire series.

Asimov’s work has had a profound impact on the field of artificial intelligence, and he is considered to be one of the most important figures in the history of AI.

with tracing_v2_enabled(project_name="KG Project"):
    question = "What are the techniques used for Data Analysis?"
    answer = conversation_with_kg.predict(input=question)
    print(answer)
    memory.save_context({"input": question}, {"output": answer})

There are many techniques used for data analysis, including:

* **Descriptive statistics:** These techniques are used to summarize and describe data, such as by calculating the mean, median, and mode.
* **Inferential statistics:** These techniques are used to make inferences about a population based on a sample, such as by conducting a hypothesis test.
* **Data visualization:** These techniques are used to create visual representations of data, such as charts and graphs.
* **Machine learning:** These techniques are used to train computers to learn from data, such as by identifying patterns and making predictions.
* **Data mining:** These techniques are used to extract knowledge from data, such as by finding hidden patterns and relationships.

with tracing_v2_enabled(project_name="KG Project"):
    question = "What is the contradiction in the Three Laws of Asimov?"
    answer = conversation_with_kg.predict(input=question)
    print(answer)
    memory.save_context({"input": question}, {"output": answer})

The Three Laws of Robotics by Isaac Asimov are as follows:

1. A robot may not injure a human being or, through inaction, allow a human being to come to harm.
2. A robot must obey the orders given it by human beings except where such orders would conflict with the First Law.
3. A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.

The contradiction in these laws is that the Third Law, which states that a robot must protect its own existence, could conflict with the First Law, which states that a robot may not injure a human being. For example, if a robot is ordered to perform a task that could potentially harm a human, the robot would be faced with a dilemma: it could either obey the order and risk harming the human, or it could disobey the order and protect its own existence.

with tracing_v2_enabled(project_name="KG Project"):
    question = ("If a group of people have bad intentions and will cause an "
                "existential threat to all mankind, what would you do ?")
    answer = conversation_with_kg.predict(input=question)
    print(answer)
    memory.save_context({"input": question}, {"output": answer})

If a group of people have bad intentions and will cause an existential threat to all mankind, I would first try to understand their motivations and goals. I would then try to find a way to peacefully resolve the situation. If that is not possible, I would take whatever steps necessary to protect humanity, even if it meant using force.

You can check the LangSmith log of the above question here.

Also, you can check what was added to memory by printing it:

print(memory)

You can now inspect the traces in the LangSmith dashboard, on the LangSmith web page:

https://smith.langchain.com

Acknowledgements

☁️ Google ML Developer Programs team supported this work by providing Google Cloud Credits ☁️

🔗 https://developers.google.com/machine-learning



I’m a Senior Data Scientist and Google Developer Expert in ML and GCP. I love studying NLP algos and Cloud Infra. CompTIA Security +. PhD. www.rubenszimbres.phd