Easy K-Means Clustering with Jupyter, Spark 2.4, and Docker

Clusters. Yum.

So you want to cluster some data. There are plenty of YouTube videos that explain clustering better than I ever could. But you’re not here for that. You’re here to bake some data clusters using the best tools in the best way.

Let’s say you have several million rows (or “points”), which today is a small amount, and you would like to learn how the columns (or “attributes”) in each row are related. You might already have some hunches. You might be completely unaware of something profound. To find out, start with clustering. Clustering divides your dataset into smaller chunks that you can then analyze further with, say, statistics, or even more clustering (it’s clusters all the way down)!
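To see what “dividing the dataset into chunks” means mechanically, here is a toy, single-machine sketch of Lloyd's algorithm, the classic k-means procedure: assign each point to its nearest centroid, move each centroid to the mean of its points, and repeat until nothing moves. This is not Spark's implementation. Spark's KMeans runs a distributed version and, by default, chooses its starting centroids with k-means||. The names ToyKMeans, fit and nearest, and the sample points, are hypothetical and exist only for this illustration.

```scala
// Toy, single-machine k-means (Lloyd's algorithm); not Spark's implementation.
object ToyKMeans {
  type Point = Vector[Double]

  // Squared Euclidean distance between two points of equal dimension.
  def sqDist(a: Point, b: Point): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  // Index of the centroid closest to p.
  def nearest(p: Point, centroids: Vector[Point]): Int =
    centroids.indices.minBy(i => sqDist(p, centroids(i)))

  // Component-wise mean of a non-empty group of points.
  def mean(ps: Seq[Point]): Point =
    ps.transpose.map(_.sum / ps.size).toVector

  // Alternate the assignment step and the update step until the
  // centroids stop moving or maxIter rounds have run.
  @annotation.tailrec
  def fit(points: Vector[Point], centroids: Vector[Point], maxIter: Int = 20): Vector[Point] = {
    // Assignment step: group every point under its nearest centroid.
    val groups = points.groupBy(p => nearest(p, centroids))
    // Update step: move each centroid to the mean of its group;
    // a centroid that attracted no points stays where it was.
    val updated = centroids.indices.map { i =>
      groups.get(i).map(g => mean(g)).getOrElse(centroids(i))
    }.toVector
    if (updated == centroids || maxIter <= 1) updated
    else fit(points, updated, maxIter - 1)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical (age, amount) points invented for this illustration.
    val points = Vector(
      Vector(20.0, 10.0), Vector(22.0, 12.0), Vector(21.0, 11.0),
      Vector(60.0, 500.0), Vector(62.0, 520.0))
    // Seed with two of the points, a stand-in for a real initialization scheme.
    val centroids = fit(points, Vector(points(0), points(3)))
    // Prints (21.0, 11.0) and (61.0, 510.0)
    centroids.foreach(c => println(c.mkString("(", ", ", ")")))
  }
}
```

The three young, low-spending points end up in one cluster and the two older, high-spending points in the other; that grouping is exactly the kind of “prediction” column Spark adds later in this walkthrough.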

The dataset used here contains purchase records with some demographics about the buyers, such as age and gender. The data is fake, and there are far too few points to draw any real conclusions from it. Instead of doing real data science, the purpose of this exercise is to install everything you need from scratch and write some Scala code in a Jupyter notebook.

Why Jupyter and not Zeppelin?

I wasted about a day on Apache Zeppelin learning that version 0.8, the latest as of this writing (released last June), uses an older version of Spark. Since it’s November 2018, we naturally want the latest version of Spark MLlib, whose primary API (the spark.ml package, informally called “Spark ML”) works on DataFrames (a DataFrame is a Dataset of Rows) instead of RDDs.

Please get your act together, Apache Zeppelin team!



1. Copy this data file to your home directory as data.csv.

2. Install Docker

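3. Start a terminal in your home directory (the directory containing data.csv, which the command below mounts into the container) and enter: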

docker run -v $PWD:/home/jovyan/work --user root -e GRANT_SUDO=yes \
-p 8888:8888 jupyter/all-spark-notebook

4. After a long wait (the first run downloads a large Docker image), Jupyter will print a URL to the console. Copy it into your browser.

5. Create a new notebook using the “Apache Toree - Scala” kernel.

6. Paste the following code into the first cell:

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.functions.col
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.types._
// Data file (no header row), one purchase per line:
// age (int)
// gender ('M' or 'F')
// days since prior purchase (int)
// month (string, 3-character month abbreviation)
// amount (float)
val schema = StructType(Array(
  StructField("age", DoubleType, true),
  StructField("gender", StringType, true),
  StructField("days", DoubleType, true),
  StructField("month", StringType, true),
  StructField("amount", DoubleType, true)))
// Read the input file with the explicit schema above.
// The docker command mounts the current directory at /home/jovyan/work.
val df = spark.read.format("csv")
  .option("header", "false")
  .schema(schema)
  .load("/home/jovyan/work/data.csv")
// Turn each string column into a category index, then into a one-hot vector.
// (OneHotEncoder is deprecated in Spark 2.4 in favor of OneHotEncoderEstimator, but still works.)
val gindexer = new StringIndexer().setInputCol("gender").setOutputCol("genderIndex")
val gencoder = new OneHotEncoder().setInputCol("genderIndex").setOutputCol("genderVec")
val mindexer = new StringIndexer().setInputCol("month").setOutputCol("monthIndex")
val mencoder = new OneHotEncoder().setInputCol("monthIndex").setOutputCol("monthVec")
// Specify the fields used for clustering (here, all five attributes) and
// combine them into the "features" column that KMeans reads by default
val assembler = new VectorAssembler()
  .setInputCols(Array("age", "genderVec", "days", "monthVec", "amount"))
  .setOutputCol("features")
// k-means model with two clusters
val kmeans = new KMeans().setK(2).setSeed(1L)
// Create a pipeline
val pipeline = new Pipeline().setStages(Array(gindexer, gencoder, mindexer, mencoder, assembler, kmeans))
// Run the pipeline
val kMeansPredictionModel = pipeline.fit(df)
// Create a dataframe with the transformed input plus a
// field named 'prediction' containing the cluster number
val predictionResult = kMeansPredictionModel.transform(df)

7. Run the first cell (Shift+Enter).

8. predictionResult is a DataFrame, so you can analyze it further. Its “prediction” column contains the cluster number for each row. In a new cell, enter the following and run it (Shift+Enter):

import org.apache.spark.sql.functions._
// Summarize the purchase amounts in each cluster
predictionResult.groupBy("prediction")
  .agg(count("amount"), sum("amount"), min("amount"),
    max("amount"), stddev_pop("amount"))
  .show()


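Here is the output for the sample data:

+----------+-------------+-----------------+-----------+-----------+------------------+
|prediction|count(amount)|      sum(amount)|min(amount)|max(amount)|stddev_pop(amount)|
+----------+-------------+-----------------+-----------+-----------+------------------+
|         1|            3|          1396.93|     361.62|     596.99|  98.0125251633121|
|         0|            7|635.6199999999999|       6.93|     269.66| 81.38010422250211|
+----------+-------------+-----------------+-----------+-----------+------------------+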
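K-means only understands numbers, so the pipeline in step 6 turns gender and month into numeric vectors before clustering. The sketch below imitates, in plain Scala, what the default settings of StringIndexer, OneHotEncoder and VectorAssembler do to a single row. It is an illustration, not Spark's code: the helper names (fitIndex, oneHot, assemble) are hypothetical, the input values are invented, ties in label frequency are broken alphabetically by assumption, and the column order (age, gender, days, month, amount) is assumed.

```scala
// Plain-Scala imitation of StringIndexer + OneHotEncoder + VectorAssembler
// defaults; Spark's own classes work on DataFrames and emit sparse vectors.
object EncodingSketch {
  // Like StringIndexer's default "frequencyDesc" order: the most frequent
  // label gets index 0. Ties are broken alphabetically (an assumption here).
  def fitIndex(values: Seq[String]): Map[String, Int] =
    values.groupBy(identity)
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .zipWithIndex
      .toMap

  // Like OneHotEncoder with dropLast = true: n labels give n - 1 slots,
  // and the last index is encoded as all zeros.
  def oneHot(index: Int, numLabels: Int): Vector[Double] =
    Vector.tabulate(numLabels - 1)(i => if (i == index) 1.0 else 0.0)

  // Like VectorAssembler: concatenate the parts in input order.
  def assemble(parts: Vector[Double]*): Vector[Double] =
    parts.reduce(_ ++ _)

  def main(args: Array[String]): Unit = {
    // Hypothetical category columns, invented for this illustration.
    val genderIdx = fitIndex(Seq("F", "M", "F"))             // F -> 0, M -> 1
    val monthIdx = fitIndex(Seq("Jan", "Jan", "Feb", "Mar")) // Jan -> 0, Feb -> 1, Mar -> 2
    // One hypothetical purchase: age 35, "M", 12 days, "Jan", amount 99.5.
    val features = assemble(
      Vector(35.0),
      oneHot(genderIdx("M"), genderIdx.size),
      Vector(12.0),
      oneHot(monthIdx("Jan"), monthIdx.size),
      Vector(99.5))
    println(features) // Vector(35.0, 0.0, 12.0, 1.0, 0.0, 99.5)
  }
}
```

One consequence of dropLast is that the two-valued gender column becomes a single 0/1 slot, and a month column holding all twelve abbreviations would become eleven slots.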

Next Steps

  • Try running the two cells again with a larger K value, such as 5 (change setK(2) to setK(5))
  • Learn PySpark. For example, send predictionResult to Pandas and do something awesome with a tiny amount of coding!
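The first cell imports ClusteringEvaluator without using it. One way to compare K values is the silhouette score it computes, which ranges from -1 to 1, with higher values meaning tighter, better-separated clusters. This is a sketch for a new cell, assuming the first cell has already run in the same notebook (so df, gindexer, gencoder, mindexer, mencoder and assembler exist) and that the assembler writes the default "features" column that KMeans and ClusteringEvaluator read.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator

// Silhouette score over the default "features" and "prediction" columns.
val evaluator = new ClusteringEvaluator()

// Refit the whole pipeline for each K (assumes the stages and df from the first cell).
val scores = (2 to 5).map { k =>
  val model = new Pipeline()
    .setStages(Array(gindexer, gencoder, mindexer, mencoder, assembler,
      new KMeans().setK(k).setSeed(1L)))
    .fit(df)
  k -> evaluator.evaluate(model.transform(df))
}
scores.foreach { case (k, s) => println(f"k=$k%d silhouette=$s%.3f") }
```

With only ten fake rows the scores will not mean much, but the same loop works unchanged on a real dataset.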


Use Vegas for Graphing

The next thing you’ll want to do is visualize your clusters. If you want to stick with Scala, Vegas is a good option. Run this once in a new cell:

%AddDeps org.vegas-viz vegas_2.11 0.3.11 --transitive

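As of this writing, the latest version of Vegas is 0.3.11; change the version number as newer versions are released. The _2.11 suffix is the Scala version the library was built for, so it needs to match the Scala version of the notebook's Spark.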

Publish Your Tips and Tricks

This story took two solid days to research and write. It’s your turn to do something cool with Spark and share your adventures.

Have fun!

The journey is the reward!