K-means clustering with Neo4j

Nathan Smith
Neo4j Developer Blog
12 min read · Aug 27, 2020

When faced with a set of many items, we often want to organize them into groups. Simplification from many individual examples to higher-level groups of similar items can help us wrap our heads around what might otherwise be an overwhelming amount of data. If we already know the different categories that we want to group our data into, and we have examples of each category, we can use supervised machine learning to classify our data. If we don’t know the categories in advance, we can use unsupervised machine learning to discover new groupings from our data.

Photo by Jarek Ceborski on Unsplash

The community detection algorithms that come in Neo4j’s Graph Data Science library are one way to apply unsupervised machine learning. We can use them to find groupings based upon relationships among items. If we need to group items based on similar properties instead of relationships, we can use clustering algorithms such as k-means. In this article, we’ll look at how we could implement k-means and the k-means++ seeding algorithm in Neo4j’s Cypher query language. We’ll also compare the results of k-means to clusters we could generate through the Louvain community detection algorithm.

I like k-means because it is a powerful algorithm that I also find easy to understand. It has a long history in data analytics. Hugo Steinhaus was one of the first to propose the idea in the 1950s.

You can use k-means when you believe that your data could be grouped into a known number of natural clusters. The algorithm assigns each example in the data to a cluster of items with similar values for the properties that you are considering. A simple example: your data points might be the heights and weights of some volunteers, and you want to group the volunteers into categories based on those two measurements.

We decide in advance how many clusters we want to divide the data into. The number of clusters that we pick is called “k”, and that’s the “k” in k-means. Next, we select one example at random from our data set for each of the k clusters. These randomly-selected points become the centroids of our clusters. Each example in the data set is assigned to the cluster that corresponds to the nearest centroid. Then we move the centroid so that it sits at the average value along each dimension for the items in the cluster. This averaging step is where we get the “means” in k-means. We repeat this process of assigning examples to their closest centroid, and then shifting the centroids to respond to their new cluster members. In time the algorithm should converge to a stable state where no examples (or at most a very few examples) are switching clusters at the assignment step.

Implementing k-means in Neo4j sounds like a fun challenge. However, there are efficient, easy-to-use implementations of k-means like this one in Python and this one in R. Do we really need to build one in Neo4j? Let’s not be like the scientists in Jurassic Park — so preoccupied with whether or not they could, they didn’t stop to think if they should.

Photo by Amy-Leigh Barnard on Unsplash

I probably wouldn’t use a Neo4j-based implementation of k-means in production, but I think it could be worth your time to keep reading. If you are already familiar with k-means, thinking about implementing it in Neo4j could be a fun exercise to sharpen your Cypher skills. If you are already familiar with Cypher, using it to write k-means could deepen your understanding of unsupervised machine learning.

Implementing k-means

Start with a Neo4j sandbox. We can use the blank project for this example.

Choose the Blank Sandbox project to get started

The computer science department at the University of Eastern Finland has made several clustering datasets available online at http://cs.joensuu.fi/sipu/datasets/ that we can use for this example. In particular, we’ll use synthetic data from the paper by P. Fränti and O. Virmajoki, “Iterative shrinking method for clustering problems”, Pattern Recognition, 39 (5), 761–765, May 2006.

The data consists of 5,000 examples that fall into 15 clusters. The file contains an x and y coordinate for each example. We’ll store the coordinates for each example using Neo4j’s point data type. This will make it easy for us to calculate the distance between examples. Load the data with this Cypher command.

LOAD CSV FROM "http://cs.joensuu.fi/sipu/datasets/s3.txt" AS row
WITH row,
toInteger(trim(substring(row[0], 0, 11))) AS x, toInteger(trim(substring(row[0], 12))) AS y
CREATE (i:Item {location:point({x:x, y:y})})
RETURN COUNT(*)
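
If you want to sanity-check the import before moving on, a quick query like this (my own addition, not part of the dataset instructions) should report 5,000 items with coordinates roughly on the 0 to 1,000,000 scale.

//Optional sanity check on the import
MATCH (i:Item)
RETURN count(i) AS itemCount,
min(i.location.x) AS minX, max(i.location.x) AS maxX,
min(i.location.y) AS minY, max(i.location.y) AS maxY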

We know from the documentation of our data set that there are 15 clusters. Let’s create 15 random centroids. Each one has a location taken from an item and an iterations property we can use to keep track of how many times we have moved the centroid.

MATCH (i:Item)
WITH i, rand() AS sortOrder
ORDER BY sortOrder
LIMIT 15
CREATE (c:Centroid)
SET c.location = i.location,
c.iterations = 0
RETURN *

Let’s assign each cluster a number so that we can tell them apart.

MATCH (c:Centroid)
WITH collect(c) AS centroids
UNWIND range(0, size(centroids) - 1) AS num
SET (centroids[num]).clusterNumber = num + 1
RETURN centroids[num]
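
If you’d like to confirm the numbering, an optional check like this should list cluster numbers 1 through 15, one per centroid.

//Optional check: list the centroids and their cluster numbers
MATCH (c:Centroid)
RETURN c.clusterNumber AS cluster, c.location.x AS x, c.location.y AS y
ORDER BY cluster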

I exported the items and the centroids to CSV and plotted them to show our starting point.

MATCH (n) 
RETURN labels(n)[0] AS nodeType,
n.location.x AS x,
n.location.y AS y

K-means random starting point

Cypher’s distance function calculates the distance between two points in the x-y plane. We can use this Cypher query to assign each item the cluster number of its nearest centroid.

//Assign each item to the cluster with the nearest centroid
MATCH (i:Item), (c:Centroid)
WITH i, c ORDER BY distance(i.location, c.location)
WITH i, collect(c) AS centroids
SET i.clusterNumber = centroids[0].clusterNumber
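
If you’re curious how evenly the random centroids split the data at this stage, a quick count per cluster will show you. This query is only for inspection and isn’t part of the algorithm.

//Optional: how many items landed in each cluster?
MATCH (i:Item)
RETURN i.clusterNumber AS cluster, count(*) AS members
ORDER BY cluster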

Now we recalculate the location of each cluster’s centroid to be the mean x and y values for all points in the cluster. I also want to keep track of how many times I have iterated through the algorithm, so I will increment the “iterations” property on each centroid.

//Move each centroid to the mean of the currently assigned items
MATCH (i:Item), (c:Centroid)
WHERE i.clusterNumber = c.clusterNumber
WITH c, avg(i.location.x) AS newX, avg(i.location.y) AS newY
SET c.location = point({x:newX, y:newY}),
c.iterations = c.iterations + 1

The last two statements make up one iteration of the k-means algorithm. We can combine them into a single query using Cypher’s WITH clause. We can also tweak the first part of the query so that it returns the number of items that were reassigned to a new cluster in this iteration. The combined query looks like this.

//Assign each item to the cluster with the nearest centroid
MATCH (i:Item), (c:Centroid)
WITH i, c ORDER BY distance(i.location, c.location)
WITH i, i.clusterNumber AS oldClusterNumber, collect(c) AS centroids
SET i.clusterNumber = centroids[0].clusterNumber
WITH i, oldClusterNumber
WHERE i.clusterNumber <> oldClusterNumber
WITH count(*) AS changedCount
//Move each centroid to the mean of the currently assigned items
MATCH (i:Item), (c:Centroid)
WHERE i.clusterNumber = c.clusterNumber
WITH changedCount, c, avg(i.location.x) AS newX, avg(i.location.y) AS newY
SET c.location = point({x:newX, y:newY}),
c.iterations = c.iterations + 1
RETURN changedCount, c.iterations AS iterations
LIMIT 1

Finally, we can wrap this statement in the apoc.periodic.commit() procedure to run it repeatedly until no points switch clusters after an iteration. Keep in mind that while the algorithm usually converges to a state where no points switch clusters, that is not guaranteed to happen. We set a hard stop after 20 iterations in case the algorithm has not converged by then.

call apoc.periodic.commit(
"MATCH (i:Item), (c:Centroid)
WITH i, c ORDER BY distance(i.location, c.location)
WITH i, i.clusterNumber AS oldClusterNumber, collect(c) AS centroids
SET i.clusterNumber = centroids[0].clusterNumber
WITH i, oldClusterNumber
WHERE i.clusterNumber <> oldClusterNumber
WITH count(*) AS changedCount
MATCH (i:Item), (c:Centroid)
WHERE i.clusterNumber = c.clusterNumber
WITH changedCount, c, avg(i.location.x) AS newX, avg(i.location.y) AS newY
SET c.location = point({x:newX, y:newY}),
c.iterations = c.iterations + 1
RETURN CASE WHEN c.iterations < 20 THEN changedCount ELSE 0 END
LIMIT 1"
,{})

I exported and plotted the clusters and centroids. We’d like to find clusters that are cohesive and distinct from each other. Overall, it looks like the algorithm did a pretty good job, but there are a few clusters that we would draw differently. If we started over from a different set of random locations for our centroids, we might get a slightly better or worse result.

K-means clustering results
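
The export for this plot can follow the same pattern as the earlier export query, with the cluster assignment added. Treat this as a sketch of that export:

MATCH (n)
RETURN labels(n)[0] AS nodeType,
n.location.x AS x,
n.location.y AS y,
n.clusterNumber AS clusterNumber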

Implementing k-means++

Since the starting point for our centroids influences the quality of our resulting clusters, we’d like to find a way to avoid a beginning configuration that will lead to a dead end. David Arthur and Sergei Vassilvitskii developed a straightforward seeding algorithm called k-means++ that helps pick good starting centroids for k-means. We start by picking any item at random to be the first centroid. Items are selected to serve as subsequent centroids with a probability proportional to the square of the distance to their closest existing centroid.

We can code this out in Cypher and get a better sense of how it works in practice. First let’s delete our old centroids.

MATCH (c:Centroid) DELETE c

Now, select a single item at random and use its location for our first new centroid.

MATCH (i:Item)
WITH i, rand() AS sortOrder
ORDER BY sortOrder
LIMIT 1
CREATE (c:Centroid)
SET c.location = i.location,
c.iterations = 0
RETURN *

We’ll build up the code to execute the rest of the algorithm one step at a time. First, we need to know the distance from each point to the closest existing centroid. Here’s a fragment of Cypher to do that. Don’t try to run it by itself. We’ll assemble multiple fragments to create the whole algorithm.

//Step 1. find the closest centroid to each item
MATCH (i:Item), (c:Centroid)
WITH i, min(distance(i.location, c.location)) AS minDistance

We’ll extend the previous statement to collect all the items into one list, and all the squared distances into a second list. We’ll end up with paired arrays, with the order of the items corresponding to the order of the squared distances.

//Step 2. Collect the items and squared distances into lists
WITH collect(i) AS items, collect(minDistance^2) AS squareDistances

We’d like to get a running total of the values in the squared distances list. Cypher list comprehensions give us a concise syntax for accomplishing this.

Imagine our list of square distances starts out like this:

[4, 9, 1, 25…]

Our first list comprehension will give us a list of progressively longer sublists. For each number idx in range(1, size(squareDistances)), we return the slice squareDistances[..idx], which contains the first idx values of the list. The code snippet looks like this:

[idx in range(1, size(squareDistances)) | squareDistances[..idx]]

Our output would start out like this:

[[4], [4, 9], [4, 9, 1], [4, 9, 1, 25]...]

The second list comprehension will add up the values in each of the sublists, giving us a running total. For each subList in the listOfLists, start a variable called sum at 0. Add each value i in the sublist to sum. The code snippet looks like this:

[subList in listOfLists | reduce(sum=0.0, i in subList| sum + i)]

Our output would start out like this:

[ 4, 13, 14, 39 …]

Here’s our query so far.

//Step 1. Find the closest centroid to each item
MATCH (i:Item), (c:Centroid)
WITH i, min(distance(i.location, c.location)) AS minDistance
//Step 2. Collect the items and squared distances into lists
WITH collect(i) AS items, collect(minDistance^2) AS squareDistances
//Step 3. Turn the squared distances into running totals
WITH items, [idx in range(1, size(squareDistances)) | squareDistances[..idx]] AS listOfLists
WITH items, [subList in listOfLists | reduce(sum=0.0, i in subList| sum + i)] AS runningTotals

We can think of our running totals list as a number line divided into segments, where each segment’s length is the squared distance for one item. We select a number at random along this number line, then take the first running total that exceeds it, which identifies the segment containing our random number. We choose the item from our item list that corresponds to that segment. This way, the probability of any item being selected is proportional to the squared distance to its nearest existing centroid.

//Step 4. Select an item with probability proportional 
//to square distance
WITH items, runningTotals, rand()*runningTotals[-1] AS cutoff
WITH items, runningTotals, [v in runningTotals where v > cutoff][0] as selected, cutoff
WITH items, selected, apoc.coll.indexOf(runningTotals, selected) AS selectedIndex
WITH items[selectedIndex] AS selectedItem

We create a new centroid at the selected item’s location, then repeat the algorithm until we have the desired number of centroids. The complete code for the algorithm is below.

//K-means++               
CALL apoc.periodic.commit(
"// Step 1. Find the closest centroid to each item
MATCH (i:Item), (c:Centroid)
WITH i, min(distance(i.location, c.location)) AS minDistance
// Step 2. Collect the items and squared distances into lists
WITH collect(i) AS items, collect(minDistance^2) AS squareDistances
// Step 3. Turn the squared distances into running totals
WITH items, [idx in range(1, size(squareDistances)) | squareDistances[..idx]] AS listOfLists
WITH items, [subList in listOfLists | reduce(sum=0.0, i in subList| sum + i)] AS runningTotals
// Step 4. Select an item with probability proportional
// to square distance
WITH items, runningTotals, rand()*runningTotals[-1] AS cutoff
WITH items, runningTotals, [v in runningTotals where v > cutoff][0] AS selected, cutoff
WITH items, selected, apoc.coll.indexOf(runningTotals, selected) AS selectedIndex
WITH items[selectedIndex] AS selectedItem
// Step 5. Create the centroid
CREATE (cnt:Centroid)
SET cnt.location = selectedItem.location,
cnt.iterations = 0
WITH selectedItem
LIMIT 1
// Step 6. Return the number of centroids left to create
MATCH (c:Centroid)
RETURN 15 - count(c)
")

Here are the centroids selected by a run of k-means++.

Centroids chosen by k-means++

Here is how the clusters turned out after running k-means from that starting point.

It may not seem like a huge improvement. There’s still randomness involved in this process, and starting from k-means++ doesn’t guarantee the best possible k-means clusters every time. It helps avoid the worst ones though. It also helps our model converge with fewer iterations.
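
One rough way to see that effect is to look at the iterations property on the centroids after the loop finishes. Each centroid is incremented once per pass, so the maximum value is roughly the number of passes the loop needed:

MATCH (c:Centroid)
RETURN max(c.iterations) AS passesToConverge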

Compare k-means with Louvain

We said at the outset that clustering algorithms work based on property similarity, and community detection algorithms work based on relationships. However, we can play with that distinction by creating relationships among nodes that have similar properties. Then, we can use the Louvain community detection module to identify communities based on those relationships.

We’ll be using Neo4j’s Graph Data Science library. Our first step will be to load the graph into memory. In our data, the x and y axes are on a scale of 0 to 1,000,000. I expect that most items in the same cluster should be no more than 200,000 units apart. We will project relationships between any pairs of nodes closer than 200,000 units. The Louvain algorithm can use a weight on each relationship. After some experimentation, I found that weighting each relationship as the inverse square of the distance between the points worked well. This inverse square weighting shows up in physics, including the way sound gets quieter as it travels across distances. I imagine the tiny voices of nodes echoing across the void.

Here’s the code to load the graph.

CALL gds.graph.create.cypher(
'clustering-demo',
'MATCH (i:Item) RETURN id(i) AS id',
'MATCH (i1:Item), (i2:Item)
WHERE distance(i1.location, i2.location) < 200000
AND id(i1) <> id(i2)
RETURN id(i1) AS source, id(i2) AS target,
(distance(i1.location, i2.location)) ^(-2) AS similarity')
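
Before running Louvain, it can be useful to check how large the projected graph turned out, since the 200,000-unit threshold generates a lot of relationships. The graph catalog can report the counts; the YIELD fields below are from GDS 1.x, so adjust them if your version differs.

CALL gds.graph.list('clustering-demo')
YIELD graphName, nodeCount, relationshipCount
RETURN graphName, nodeCount, relationshipCount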

Now we can run Louvain community detection and count the number of examples in each community.

CALL gds.louvain.stream('clustering-demo', {relationshipWeightProperty:'similarity'})
YIELD nodeId, communityId
RETURN communityId, count(nodeId)

Community detection results

Unlike k-means, I can’t specify the number of communities that I will get back from Louvain. This run of the algorithm gave me 22 communities. From the first five rows of the results, I can see that a few of the communities have very few nodes.

Let’s write these communities back to the graph, then export and plot them.

CALL gds.louvain.write('clustering-demo', {relationshipWeightProperty:'similarity', writeProperty:'louvainCommunity'})
YIELD communityCount, modularities
RETURN communityCount, modularities
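
The export for plotting can follow the same pattern as before, this time returning the louvainCommunity property we just wrote (adapt the property name if you chose a different writeProperty):

MATCH (i:Item)
RETURN i.location.x AS x,
i.location.y AS y,
i.louvainCommunity AS community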

Louvain communities

The fifteen large clusters look really good. The small communities could be merged into their bigger neighbors.
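
If you wanted to take that merging idea further, a first step could be flagging the undersized communities. The threshold of 50 here is arbitrary, chosen only for illustration:

//Communities with fewer than 50 members (threshold chosen for illustration)
MATCH (i:Item)
WITH i.louvainCommunity AS community, count(*) AS members
WHERE members < 50
RETURN community, members
ORDER BY members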

I hope you have had fun exploring unsupervised machine learning from a few different angles in Neo4j. The tools of Cypher and the Graph Data Science library can help you see your data and the algorithms we use in a new light.

Nathan Smith
Neo4j Developer Blog

Senior Data Scientist at Neo4j. Organizer of the Kansas City Graph Databases Meetup.