[Project] Golden Ticket to Big Data: Exploring Wonka’s Candy Sales with Spark
Join Willy Wonka’s team and learn how to use Spark to analyze candy sales data from around the world
Imagine you opened your chocolate bar and found a golden ticket. What you didn’t know is that instead of winning a trip, you won an internship to work for Willy Wonka! Your first task is showing that you know how to handle Spark by creating an application of your choice.
To avoid doing the “Hello World!” of distributed computing (counting words), let’s try to be a little more creative and analyze some data about Wonka’s global candy sales. To make it a bit more interesting, we’ll work with a dataset of around 100k rows, just enough to enjoy Spark’s capabilities.
Simulating Data
Before we start anything, let’s import the necessary libraries and create an empty list to store our data.
import pandas as pd
import numpy as np
data = []
Let’s simulate the data with a Python script that creates entries representing how much candy was sold in each country. To do this, first we create a list of the countries where Wonka sells his candy, and a list of the types of candy he sells.
countries = ["USA", "MEX", "CAN", "DEU", "ITA", "FRA", "CHN", "RUS", "SAU", "ARE", "GBR", "TUR", "IND", "BRA"]
candy_types = ["chocolate bar", "white chocolate bar", "dark chocolate bar", "blueberry bubblegum", "caramel popcorn", "peanut butter pops", "chocolate cookies", "butter cookies", "gummy bears", "lollipops"]
Now let’s create a for-loop that runs 100k times (or more, depending on how many entries you want). On each iteration we’ll randomly pick a country, a candy, and a sales figure between 100,000 and 10,000,000, and append the entry to our list as a dictionary.
for i in range(100000):
    country = np.random.choice(countries)
    candy = np.random.choice(candy_types)
    sales = np.random.randint(100000, 10000000)
    data.append({'country': country, 'candy': candy, 'sales': sales})
Finally, let’s convert our data list to a Pandas DataFrame, and then save it as a CSV file.
df = pd.DataFrame(data)
df.to_csv('candy_sales.csv', index=False) # We don't need to store the index
Yay! Now we have some data to work with. Remember that you can change the number of entries in the file and the range of the sales figures if you want, as sketched below.
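Here is a minimal sketch of what that could look like, with the knobs pulled out into constants. The names NUM_ROWS, MIN_SALE and MAX_SALE are just illustrative, not part of anything we’ve written so far.
NUM_ROWS = 500_000                         # how many sales records to simulate
MIN_SALE, MAX_SALE = 100_000, 10_000_000   # range of a single sales figure

data = []
for _ in range(NUM_ROWS):
    data.append({
        'country': np.random.choice(countries),
        'candy': np.random.choice(candy_types),
        'sales': np.random.randint(MIN_SALE, MAX_SALE),
    })

pd.DataFrame(data).to_csv('candy_sales.csv', index=False)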
Working With Our Data
This is where the fun begins. As we did before, let’s start by importing the necessary libraries.
import sys
from pyspark.sql import SparkSession
from pyspark.sql.functions import count
To get the ball rolling, we should create a Spark session using its builder API. Remember that there can only be one SparkSession per JVM (Java Virtual Machine), so getOrCreate() returns the existing session if one is already running and only creates a new instance otherwise.
spark = (SparkSession
         .builder
         .appName("CandySalesCount")
         .getOrCreate())
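As a quick, optional sanity check, calling getOrCreate() a second time hands back the session we already have instead of building a new one:
# getOrCreate() returns the already-active session rather than a second one
same_spark = SparkSession.builder.getOrCreate()
print(spark is same_spark)  # True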
Now we need to get the filename for our candy sales report. This part comes down to personal preference: we can hard-code it, or we can pass it through the command line. This is what both options look like.
# Command-line option
candy_sales_file = sys.argv[1]
# Hard-coded option
candy_sales_file = "./candy_sales.csv"
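If you go with the command-line option, the filename arrives as the first argument when you submit the script (for example, spark-submit candy_sales_count.py candy_sales.csv, where the script name is just a placeholder). A small variation that falls back to the hard-coded path when no argument is given:
# Use the command-line argument if provided, otherwise fall back to a default
candy_sales_file = sys.argv[1] if len(sys.argv) > 1 else "./candy_sales.csv"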
Next we should load our file into a Spark DataFrame using the CSV format. To do that, we tell Spark that our file contains a header (so the column names are read from the first row) and to infer the schema (so the column types are detected automatically).
candy_sales_df = (spark.read.format("csv")
                  .option("header", "true")
                  .option("inferSchema", "true")
                  .load(candy_sales_file))
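Before going further, it doesn’t hurt to confirm that the header and the inferred types came through as expected:
# The schema should show country and candy as strings and sales as an integer
candy_sales_df.printSchema()
candy_sales_df.show(5, truncate=False)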
For this exercise we’ll be using the high-level DataFrame API instead of RDDs. The exercise is simple enough that we want to focus on telling Spark what to do, not how to do it. This also lets us chain function calls, since each transformation returns a new DataFrame.
count_candy_sales_df = (candy_sales_df
                        .select("country", "candy", "sales")
                        .groupBy("country", "candy")
                        .agg(count("sales").alias("Total"))
                        .orderBy("Total", ascending=False))
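Note that nothing has actually executed yet; these are lazy transformations. If you’re curious about what Spark plans to do, you can inspect the query plan before triggering it:
# Prints the query plan; no data is processed at this point
count_candy_sales_df.explain()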
Let’s see the resulting aggregation: a total count of candy sales records per country and candy. Remember that show() is an action, so it will trigger execution of the query we just wrote.
count_candy_sales_df.show(n=60, truncate=False)
print(f"Total Rows = {count_candy_sales_df.count()}")
+-------+-------------------+-----+
|country|candy |Total|
+-------+-------------------+-----+
|ITA |butter cookies |779 |
|CAN |white chocolate bar|769 |
|ARE |blueberry bubblegum|767 |
|CAN |lollipops |762 |
|MEX |caramel popcorn |760 |
|IND |butter cookies |760 |
|RUS |butter cookies |757 |
|FRA |white chocolate bar|756 |
|ARE |gummy bears |755 |
|TUR |butter cookies |753 |
|BRA |gummy bears |752 |
|ARE |lollipops |751 |
|ITA |white chocolate bar|750 |
|TUR |gummy bears |750 |
|CHN |butter cookies |749 |
|RUS |lollipops |749 |
|TUR |peanut butter pops |748 |
|CHN |caramel popcorn |748 |
|ITA |lollipops |745 |
|FRA |gummy bears |744 |
|CHN |dark chocolate bar |742 |
|IND |peanut butter pops |742 |
...
+-------+-------------------+-----+
This gives us an overview of how many sales records each candy has in each country. It doesn’t mean that 779 butter cookies were sold in Italy; it means there were 779 sales records for butter cookies in Italy (the actual units sold are in the millions, since every record carries a sales figure between 100,000 and 10,000,000). If we want to dig a bit deeper, we can look at a specific country with the following code.
ita_count_candy_sales_df = (candy_sales_df
                            .select("country", "candy", "sales")
                            .where(candy_sales_df.country == "ITA")
                            .groupBy("country", "candy")
                            .agg(count("sales").alias("Total"))
                            .orderBy("Total", ascending=False))
ita_count_candy_sales_df.show(n=10, truncate=False)
+-------+-------------------+-----+
|country|candy |Total|
+-------+-------------------+-----+
|ITA |butter cookies |779 |
|ITA |white chocolate bar|750 |
|ITA |lollipops |745 |
|ITA |blueberry bubblegum|741 |
|ITA |gummy bears |723 |
|ITA |chocolate cookies |717 |
|ITA |caramel popcorn |702 |
|ITA |dark chocolate bar |701 |
|ITA |peanut butter pops |696 |
|ITA |chocolate bar |696 |
+-------+-------------------+-----+
Remember to stop the Spark session when you’re finished. As discussed in a previous article, each session holds resources such as memory, CPU and connections to other services, and stopping the session releases all of them. Leaving the session running keeps those resources tied up, which can hurt performance, and in a shared environment it can even affect other users who need them. As a good practice, always stop your Spark session when you’re done using it.
spark.stop()
What’s next?
People at Wonka were impressed with your Spark skills. They said that they’ve been looking for someone with your capabilities.
If you want to get hired, consider doing the following:
- Get the total sales for each candy
- Get the total sales for each candy by country
- Get the total sales for each country
To solve these exercises, you might want to look into aggregation functions other than count; a small hint for the first one follows.
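Here is a minimal sketch of that first exercise, swapping count for sum from pyspark.sql.functions. Treat it as a starting point rather than the answer, and note that it has to run before spark.stop() is called.
from pyspark.sql.functions import sum as sum_

# Total sales for each candy, across all countries
# (run this before calling spark.stop())
total_sales_per_candy_df = (candy_sales_df
                            .groupBy("candy")
                            .agg(sum_("sales").alias("TotalSales"))
                            .orderBy("TotalSales", ascending=False))

total_sales_per_candy_df.show(truncate=False)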