Approximate Nearest Neighbors (ANN) Algorithm using KD-Trees from Scratch in Python

Pradyumna
8 min read · Aug 9, 2024


Introduction

The Approximate Nearest Neighbor (ANN) algorithm is a powerful technique used to quickly find points in high-dimensional spaces that are close to a given query point. ANN provides a balance between accuracy and speed, making it suitable for large datasets and real-time applications such as recommendation systems, image retrieval, and anomaly detection.

In this blog, let’s understand the ANN algorithm and implement it from scratch using KD-Trees in Python.

Need for ANN

The traditional K-Nearest Neighbor (KNN) algorithm is simple and intuitive. Given a dataset and a query point, it calculates the distance from the query point to every other point in the dataset and identifies the K closest ones. While this brute-force approach works well for small datasets or low-dimensional spaces, it quickly becomes computationally prohibitive as the size of the dataset or the dimensionality of the space increases. The time complexity of a KNN search is O(n), where n is the number of points in the dataset, meaning the computation time grows proportionally with the size of the dataset. This can lead to significant performance bottlenecks.
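For reference, here is a minimal brute-force KNN sketch (the function name knn_brute_force and the choice of k are illustrative, not from the original post); it performs exactly the O(n) scan over every point that ANN tries to avoid.

import numpy as np

# brute-force KNN: compute the distance from the query to every point and keep the k smallest
def knn_brute_force(points, query, k=5):
    points = np.asarray(points)
    query = np.asarray(query)
    # Euclidean distance from the query to each point in the dataset
    distances = np.linalg.norm(points - query, axis=1)
    # indices of the k smallest distances
    nearest_idx = np.argsort(distances)[:k]
    return points[nearest_idx]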

Therefore, we need a way to speed up the nearest neighbor search, and one such technique is called Approximate Nearest Neighbor (ANN). As its name suggests, ANN makes some approximations when finding the neighbors. The idea is to trade off a bit of accuracy for a significant gain in speed. This trade-off between accuracy and latency is acceptable in many practical applications, where finding neighbors quickly matters more than finding the exact ones.

ANN Algorithm

ANN algorithms use various techniques such as spatial partitioning, hashing, and dimensionality reduction. These methods help in efficiently narrowing down the search space, thus reducing the number of distance computations required.

Spatial Partitioning

Spatial partitioning is a technique that divides the data space into smaller, more manageable regions. This division allows for efficient searching and querying, especially in high-dimensional spaces where brute-force methods like traditional KNN become computationally expensive.

For simplicity and better visualization, let us take 2D data points to understand spatial partitioning. For that, I have generated 1000 random 2D points.

# importing the required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random

# defining the number of random vectors (let us consider 1000 random points)
n_vectors = 1000

# initializing an empty list to store the points
vectors = []

# generating 1000 random 2D points
for i in range(n_vectors):
    # x and y co-ords of the 2D point (random numbers between 0 and 100)
    x = round(100*random.random(), 2)
    y = round(100*random.random(), 2)
    vectors.append(np.array([x, y]))

# visualizing the datapoints
fig = plt.figure(figsize=(7, 7))
for i in range(n_vectors):
    plt.scatter(vectors[i][0], vectors[i][1], c="black", marker="x")
plt.show()
Random 2D vectors

Step 1: Pick 2 random points from the given points

Random points

Step 2: Draw an equidistant hyperplane between these two vectors

Step 3: Now we have two separate regions; perform steps 1 and 2 in each of these two regions separately

I) Now pick two random points from region 1 (below blue hyperplane)

II) Draw another equidistant hyperplane between these two newly picked random points

III) Now pick two random points from Region 2 (above blue hyperplane)

IV) Now draw another equidistant hyperplane (green) between these two chosen points

Now we have divided the whole 2D space into 4 manageable regions. Recursively perform steps 1 and 2 to keep subdividing these regions until each region has at most some K points.
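For completeness (the helper function below implies this but the post does not state it explicitly): the hyperplane equidistant from two points v1 and v2 is perpendicular to the line joining them and passes through their midpoint. With normal vector n = v2 − v1 and midpoint m = (v1 + v2) / 2, the hyperplane is the set of points x satisfying n · x = n · m. A point p lies on one side if n · p < n · m and on the other side otherwise, which is exactly the test the helper functions perform.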

Before building the tree, let us look at some helper functions.

# find the equidistant hyperplane between the two given vectors
def hyperplane_equation(v1: np.ndarray, v2: np.ndarray):
    """
    returns the vector normal to the hyperplane equidistant from v1 and v2 and the constant term in the equation of the hyperplane
    v1 : vector1
    v2 : vector2
    """
    # finding the normal vector
    normal_vector = v2 - v1
    # finding the midpoint
    midpoint = (v1 + v2) / 2
    # finding the constant term
    const_term = np.dot(normal_vector, midpoint)
    return normal_vector, const_term


# checking which side of the hyperplane the given point (vector) lies on
def check_vector_side(normal_vector, constant, vector):
    """
    returns which side of the hyperplane the vector lies on
    normal_vector : the vector normal to the hyperplane
    constant : constant term in the hyperplane equation
    vector : the vector whose side we want to find
    """
    # performing the dot product between the normal vector and the vector
    result = np.dot(normal_vector, vector)
    if result < constant:
        side = "right"
    else:
        side = "left"
    return side
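As a quick sanity check, here is a small, hypothetical usage of these helpers (the specific points are made up purely for illustration):

# two arbitrary 2D points (illustrative values)
p1 = np.array([10.0, 20.0])
p2 = np.array([60.0, 80.0])

# normal vector and constant term of the hyperplane equidistant from p1 and p2
normal, const = hyperplane_equation(p1, p2)
print(normal, const)  # [50. 60.] 4750.0

# check which side of that hyperplane two other points fall on
print(check_vector_side(normal, const, np.array([0.0, 0.0])))      # "right" (dot product < constant)
print(check_vector_side(normal, const, np.array([100.0, 100.0])))  # "left"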

Spatial Partitioning using KD Tree

One of the most popular and effective spatial partitioning techniques is using the KD-Tree (short for K-Dimensional Tree), which is widely used in Approximate Nearest Neighbor (ANN) algorithms. Now, let’s try to implement this spatial partitioning technique with KD trees.

The algorithm comprises two phases:

1. Building the KD Tree

2. Searching with a KD Tree

Building the KD Tree

Let’s first define the node of a KD tree. Each node represents a region of the d-dimensional space and stores a hyperplane drawn between two random points in that region, along with exactly two children, left and right: the points that lie on the left-hand side of the hyperplane are stored in the left child, and the points that lie on the right-hand side are stored in the right child.

class Node:
    def __init__(self, hyperplane=None, constant=None, values=None):
        """
        hyperplane : the normal vector of the hyperplane equidistant from the two chosen points
        constant : the constant term in the hyperplane equation
        values : the vectors to separate based on the hyperplane
        """
        self.hyperplane = hyperplane
        self.constant = constant
        self.values = values
        # points that lie to the left hand side of the hyperplane
        self.left = None
        # points that lie to the right hand side of the hyperplane
        self.right = None

Let’s see the code that creates the KD tree by recursively dividing the space:

# building the tree
def build_tree(vectors):
    """
    builds the tree used to find approximate nearest neighbors and returns the root node of the tree
    vectors : list of all vectors
    """
    idx1 = 0
    idx2 = 0
    while idx1 == idx2:
        # pick two random indices within the range of the number of vectors
        idx1 = random.randint(0, len(vectors)-1)
        idx2 = random.randint(0, len(vectors)-1)
    # pick the two random vectors from the list of vectors
    first_vector = vectors[idx1]
    second_vector = vectors[idx2]
    # find the equidistant hyperplane between the first vector and the second vector
    hyperplane, constant = hyperplane_equation(first_vector, second_vector)
    # classify whether each vector in the list lies to the left or right of the hyperplane
    left_nodes = []
    right_nodes = []
    for vector in vectors:
        side = check_vector_side(hyperplane, constant, vector)
        # appending the vector to the corresponding side
        if side == "right":
            right_nodes.append(vector)
        elif side == "left":
            left_nodes.append(vector)
    # building the current tree node
    current_node = Node(hyperplane=hyperplane, constant=constant, values=vectors)

    # if the left subset is larger than min_subset_size, we need to split it further
    if len(left_nodes) > min_subset_size:
        current_node.left = build_tree(left_nodes)
    else:
        current_node.left = Node(values=left_nodes)

    # if the right subset is larger than min_subset_size, we need to split it further
    if len(right_nodes) > min_subset_size:
        current_node.right = build_tree(right_nodes)
    else:
        current_node.right = Node(values=right_nodes)

    return current_node
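Note that build_tree refers to min_subset_size, which is never defined inside the function; it is assumed to be a module-level constant that caps how many points a leaf region may hold. The value below is just an illustrative choice:

# maximum number of points allowed in a leaf region (illustrative value)
min_subset_size = 10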

Searching with a KD Tree

Let’s write the function that searches the KD Tree for the nearest neighbors

# searching for the nearest neighbors through the tree
def search(tree, query_vector):
    """
    returns the list of nearest neighbors of the query vector by traversing the KD tree
    tree : KD tree built using the build_tree function
    query_vector : vector whose nearest neighbors we want to find

    a leaf node satisfies three conditions:
    1. the size of the values attribute of the leaf node is at most min_subset_size
    2. the hyperplane attribute of the leaf node is None (no hyperplane is needed since the node is not split further)
    3. the constant attribute of the leaf node is None (no hyperplane is needed since the node is not split further)
    """
    # traversing until we reach a leaf
    while len(tree.values) > min_subset_size and tree.hyperplane is not None and tree.constant is not None:
        # checking which side of the hyperplane the query vector lies on
        side = check_vector_side(tree.hyperplane, tree.constant, query_vector)
        if side == "left":
            print(f"go to left {len(tree.values)}")
            tree = tree.left
        elif side == "right":
            print(f"go to right {len(tree.values)}")
            tree = tree.right
    print("The Neighbors are\n", tree.values)
    return tree.values
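The search above returns every point in the leaf region the query falls into, not the k single closest ones. A common follow-up step (not part of the original post; the name top_k_neighbors is illustrative) is to re-rank those candidates by exact distance and keep the top k:

# re-rank the leaf candidates returned by search() by exact Euclidean distance
def top_k_neighbors(tree, query_vector, k=5):
    candidates = search(tree, query_vector)
    query = np.asarray(query_vector)
    # sort the candidate vectors by their distance to the query and keep the k closest
    candidates = sorted(candidates, key=lambda v: np.linalg.norm(v - query))
    return candidates[:k]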

Example

First, we need to build a KD tree using all of our data points, so let’s synthesize our own data.

# importing the required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random

# defining the number of random vectors (let us consider 1000 random points)
n_vectors = 1000

# initializing an empty list to store the points
vectors = []

# generating 1000 random 2D points
for i in range(n_vectors):
    # x and y co-ords of the 2D point (random numbers between 0 and 100)
    x = round(100*random.random(), 2)
    y = round(100*random.random(), 2)
    vectors.append(np.array([x, y]))

Now we need to build the KD Tree

# testing the build tree function
test_tree = build_tree(vectors)

Let us create a query point and search for its nearest neighbors using the ANN algorithm.

# testing the search function
sample = [50,50]
nearest_neighbors = search(test_tree,sample)
print(f"The Neighbors are \n {nearest_neighbors}")

Let’s visualize the neighbors with the sample point

# visualizing the embeddings
fig = plt.figure(figsize=(7, 7))
for i in range(n_vectors):
    plt.scatter(vectors[i][0], vectors[i][1], c="black", marker="x")
for j in nearest_neighbors:
    plt.scatter(j[0], j[1], c="red", marker="x")
plt.scatter(sample[0], sample[1], c="yellow", marker="x")
plt.show()
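Because the neighbors are approximate, it is worth sanity-checking them against an exact brute-force scan. The helper below (illustrative, not part of the original post) reports what fraction of the exact k nearest neighbors also appear among the leaf candidates returned by the KD tree:

# fraction of the exact k nearest neighbors that appear among the leaf candidates
def recall_at_k(vectors, leaf_candidates, query, k=5):
    query = np.asarray(query)
    all_points = np.asarray(vectors)
    # exact k nearest neighbors found by brute force
    exact_idx = np.argsort(np.linalg.norm(all_points - query, axis=1))[:k]
    exact = {tuple(p) for p in all_points[exact_idx]}
    found = {tuple(p) for p in np.asarray(leaf_candidates)}
    return len(exact & found) / k

print(recall_at_k(vectors, nearest_neighbors, sample, k=5))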

Conclusion

In conclusion, Approximate Nearest Neighbor (ANN) algorithms, especially those using KD-Trees, balance speed and accuracy effectively. Traditional K-Nearest Neighbors (KNN) can be slow with large or complex data, but ANN provides quicker results by accepting an approximation. KD-Trees help by organizing the data into smaller, manageable parts for faster searches. This makes ANN ideal for real-time applications like recommendations and image search that require very low latency.

I have implemented this as a Python package called approxKD; check it out for more details.
