How to build a KD Tree in Python to support applications in Vector Databases and Deep Learning

Saptarshi Chaudhuri
6 min read · Feb 25, 2024
Image by Copilot (Microsoft + Open AI collaboration) used with permission

Vector Databases have become increasingly essential to building LLM applications. In this educational blog series I implement key vector store/retrieval algorithms from scratch to improve your understanding of vector databases.

No prerequisite knowledge needed except basic Python

This article is self-contained and requires no prerequisites except basic Python knowledge and a cursory, high-level understanding of binary trees.

However, if you like the topic of vector databases, I recommend also checking out the first two blogs in this series: Introduction to Vector Databases and Dictionary based vector databases.

In this blog we will implement vector store and retrieval using KD Trees.

Goal of this article — Why learn about vector databases?

Vector embeddings can codify the similarities and differences between real world objects in a numerical form, which makes them essential for deep learning.

These vector embeddings (nothing but large arrays of numbers) are stored in special databases called vector databases to speed up deep learning applications.

So, if you are a budding enthusiast or pro in deep learning, chances are high you will need to know vector databases.

I am both an enthusiast and a practitioner in deep learning. I learn best by coding things from scratch. My goal in this article is to share my learning experience of coding vector store and search algorithms with all of you.

What is a KD Tree and why is it useful for vector store & search?

KD Trees are a kind of data structure used by vector databases to store vectors (arrays of numbers). They allow for efficient nearest neighbor search and retrieval of vectors, which is often essential for applications built on deep learning models.

For instance, let’s say you have a dataset of people’s names, ages and salaries. A KD Tree may decide to store the data in the following manner (source: Wikipedia):

(Figure: example KD Tree, from Wikipedia)

How is a KD Tree constructed?

Construction is very intuitive and easy:

Look at the above example — there are 3 pieces of information: name, age and salary. In KD tree parlance these are called axes.

Each level of the tree cycles through the different axes: in the example above, the root level uses “name”, the next level uses “age”, and the level after that uses “salary”. Then it cycles back to “name”, “age” and “salary”, continuing recursively until all datapoints in the dataset are stored.
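For instance, here is a tiny sketch of the axis-cycling rule. The axis names are only for this example; a real KD Tree uses the index within the vector rather than a name.

```python
# Hypothetical axes for the example above; a real KD Tree would use the
# index within the vector rather than a name.
axes = ["name", "age", "salary"]

def axis_for_level(level):
    return axes[level % len(axes)]  # cycle through the axes by depth

print([axis_for_level(level) for level in range(6)])
# ['name', 'age', 'salary', 'name', 'age', 'salary']
```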

Also, while constructing each level, you only look at the subset of the dataset that falls under the parent node. For instance, in the picture below, you would be looking at all datapoints whose salary is less than $80K.

Node insertion: You select the axis and choose the median point along this axis; all points less than the median end up in the left subtree, while those greater end up in the right subtree.

In the above example, first you determine the axis for the level, “salary”, so you pick the median salary point, which is $80K.

Then you identify the axis for the next level — “name”.

Your left child is then the data point with the median name among all the points with salary less than $80K, while your right child is the point with the median name among those with salary greater than $80K.

This process continues recursively for each level and node.

Learn through coding — class and method definitions

Enough talk; let’s code, and things will get much clearer. But before that, kudos to “https://github.com/Vectorized/Python-KD-Tree/” for developing a simple, readable, yet executable code base on KD Trees.

Here are the basic class and method definitions we will use, with explanatory comments.
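The snippets in this post follow a small set of assumptions of mine rather than the credited repo verbatim: each node is a plain Python list [left_child, right_child, point], and distances are squared Euclidean. A minimal sketch of the definitions:

```python
class KDTree:
    """A simple KD Tree storing k-dimensional points (lists of numbers).

    Each node is represented as a list: [left_child, right_child, point].
    """

    def __init__(self, points, dim):
        """Remember the dimensionality k and build the tree from the initial points."""
        self._dim = dim
        self._root = self._build(list(points))

    def _build(self, points, i=0):
        """Recursively build the tree, cycling through the axes via index i."""

    def insert(self, point):
        """Insert a new point by walking down to an empty child slot."""

    def _dist(self, a, b):
        """Squared Euclidean distance (the square root is not needed for ranking)."""
        return sum((x - y) ** 2 for x, y in zip(a, b))

    def get_knn(self, point, k):
        """Return the k stored points nearest to the query point."""

    def walk(self):
        """Print the stored points level by level."""
```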

Class and Build methods

Let’s add the necessary code to the class and Build methods:

The code in __init__ is self-explanatory.

The purpose of the “build” method is to create the tree for the first time, given a set of initial points. As mentioned in the example above, it constructs the tree recursively (see the sketch after this list) by

  • first selecting the axis of the vector (essentially the index within the vector),
  • second choosing the median of the dataset along that axis,
  • finally, creating a left and right subtree by partitioning the dataset into points less than and greater than the median.
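A minimal sketch of the “build” method under those assumptions (the body belongs inside the KDTree class):

```python
# Inside class KDTree: a minimal sketch of the build step.
def _build(self, points, i=0):
    if not points:                           # empty partition: no subtree here
        return None
    points.sort(key=lambda p: p[i])          # order the points along axis i
    m = len(points) // 2                     # the median index splits the set
    j = (i + 1) % self._dim                  # children cycle to the next axis
    return [self._build(points[:m], j),      # left: points below the median
            self._build(points[m + 1:], j),  # right: points above the median
            points[m]]                       # this node stores the median point
```

Sorting at every level is the simplest choice to read; a production implementation would typically use a linear-time median selection instead.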

Insert a new point

Whenever inserting a new point into the KD Tree, we need to traverse the tree to identify the right level and parent node for this new point. This is done recursively (a code sketch follows the steps):

Start at the root: Begin at the root node of the tree.

Compare and move: Compare the point to be inserted with the current node along the splitting dimension ‘i’. If the point’s value in the splitting dimension is less than that of the current node, move to the left child; otherwise, move to the right child.

Repeat until a null node is found: Repeat the comparison and movement step until a null node is reached. This is where the new point will be inserted.
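A sketch of insertion under the same assumptions, written iteratively for brevity (it belongs inside the KDTree class):

```python
# Inside class KDTree: a sketch of insert.
def insert(self, point):
    if self._root is None:                         # empty tree: the new point becomes the root
        self._root = [None, None, point]
        return
    node, i = self._root, 0
    while True:
        side = 0 if point[i] < node[2][i] else 1   # smaller on axis i goes left, else right
        if node[side] is None:                     # reached a null child: insert here
            node[side] = [None, None, point]
            return
        node = node[side]                          # otherwise keep walking down
        i = (i + 1) % self._dim                    # the next level uses the next axis
```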

Get K-nearest neighbors

Getting the k nearest nodes to a given query vector (or point) is the main purpose of the KD Tree. It is done in the following steps (a code sketch follows):

1. Start at the root node with an empty heap (it will hold the k closest candidates found so far)

2. Measure the distance between the root node and the query point

3. Also measure the distance between the node and the query point along the splitting axis (this determines whether we look for nearest neighbors in the left or the right subtree first)

4. If the heap doesn’t have k elements yet, insert the current node into the heap; if it already has k elements, then, between the current heap max and the current node, keep whichever is closer to the query point.

5. Cycle the axes — we will now move to searching the next level of the KD Tree

6. If the query point lies below the current node along the axis, move to the left subtree first; otherwise move to the right. Repeat the above steps recursively. On the way back up, if the axis distance (from step 3) is smaller than the distance to the farthest of your current candidates (or you have fewer than k candidates yet), search the other subtree too, since it may contain closer points.

7. Finally, if at any point of the search you encounter two equidistant points, you can simply keep one; the logic to break ties is purely application dependent.

At each level of the search process you select either the left or the right subtree to continue, and the other side is only revisited when it may contain closer points. Hence the time complexity of Get_Knn is approximately O(log n) on average, where “n” is the number of nodes in the tree. This speed-up in search is the main benefit of using KD Trees for vector databases.
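Putting the steps together, here is a sketch of get_knn under the same assumptions. Python’s heapq is a min-heap, so distances are negated: heap[0] is then the farthest of the k best candidates, i.e. the one to evict when a closer point turns up.

```python
import heapq

# Inside class KDTree: a sketch of get_knn.
def get_knn(self, point, k):
    heap = []  # holds (-distance, point) for the k best candidates so far

    def search(node, i):
        if node is None:
            return
        d = self._dist(node[2], point)              # step 2: distance to this node
        dx = node[2][i] - point[i]                  # step 3: signed offset along axis i
        if len(heap) < k:
            heapq.heappush(heap, (-d, node[2]))     # step 4: heap not full yet, keep it
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, node[2]))  # closer than the current max: swap it in
        j = (i + 1) % self._dim                     # step 5: cycle the axes
        near, far = (node[0], node[1]) if dx > 0 else (node[1], node[0])
        search(near, j)                             # step 6: search the nearer side first
        if len(heap) < k or dx * dx < -heap[0][0]:
            search(far, j)                          # the far side may still hide closer points

    search(self._root, 0)
    return [p for _, p in sorted(heap, reverse=True)]  # nearest first
```

With the pieces assembled, KDTree(points, dim).get_knn(query, k) returns the k nearest stored points, nearest first.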

Simple walk

Finally, a nifty piece of code to print your KD Tree:
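The snippet below is a minimal sketch of one such walk, assuming the node layout used throughout: a breadth-first traversal that prints each point indented by its depth.

```python
from collections import deque

# Inside class KDTree: a sketch of a simple breadth-first walk.
def walk(self):
    queue = deque([(self._root, 0)])
    while queue:
        node, depth = queue.popleft()
        if node is None:
            continue
        print("    " * depth + str(node[2]))    # deeper nodes are indented further
        queue.append((node[0], depth + 1))      # enqueue the left child
        queue.append((node[1], depth + 1))      # enqueue the right child
```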

Conclusion — A geometric interpretation

Thank you for making it this far. For more curious readers, I will leave you with a geometric interpretation of KD Trees.

KD Trees are part of a larger family of algorithms called “space partitioning methods” for building vector databases. From a geometric standpoint, you can picture the vector database as K-dimensional points spread across a coordinate space with K axes.

A KD Tree partitions this space into multiple subspaces, the partitions being hyperplanes (or lines, in the case of 2D vectors) parallel to the axes.

Whenever doing a KNN search, rather than searching the entire space of points, the KD Tree lets you search only the relevant partitions while discarding the rest; this property is what speeds up the search process significantly.

The diagram above is taken from https://www.cs.cmu.edu/~ckingsf/bioinfo-lectures/kdtrees.pdf

Follow me on Twitter: @SrishiC and LinkedIn: Saptarshi

