Exploratory Data Analysis using Pyspark on Retail Sales Data

Soyoung L
7 min read · Jan 7, 2024


In the realm of burgeoning data sizes, data scientists are turning to PySpark over Pandas for its scalability. Join me today as I delve into the fundamentals of Exploratory Data Analysis (EDA) using PySpark, specifically focusing on Online Retail Sales Data. This walkthrough will take place within the dynamic environment of Databricks.


1. What does the overall data look like?

There are multiple ways to look at the data: display() or show().

display(df)
df.show(10) # you can choose how many rows you want to show 

If you want to look at the overall schema and data types, use printSchema().

df.printSchema()

What if you want an overall summary of the DataFrame? PySpark offers the describe() function, which is equivalent to describe() in Pandas.

df.describe()

describe() provides summary statistics, such as count, mean, stddev, min, and max, for each numeric column. Just by looking at this summary, we can see that there are outliers in Quantity and UnitPrice. Let's take a look at them!
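Depending on your environment, df.describe() alone may only return a DataFrame without rendering it; chaining show() or wrapping it in display() makes the statistics visible. As an optional extra (my addition, not part of the original walkthrough), PySpark's summary() also reports approximate percentiles:

# render the summary stats as text, or as a Databricks table
df.describe().show()
# summary() additionally includes the 25%, 50%, and 75% percentiles
display(df.summary())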

2. Check & Remove Outliers

Outlier handling is another topic by itself. In this section, I’ll illustrate a basic approach to address outliers by removing values that significantly deviate from the mean and standard deviation of a column.

Let's first calculate the means and standard deviations of Quantity and UnitPrice. Please note that you have to import mean and stddev from pyspark.sql.functions to perform these calculations.

from pyspark.sql.functions import mean, stddev

qty_mean = df.select(mean('Quantity')).first()[0]
qty_std = df.select(stddev('Quantity')).first()[0]
price_mean = df.select(mean('UnitPrice')).first()[0]
price_std = df.select(stddev('UnitPrice')).first()[0]

It’s important to retrieve the calculated value using first()[0] after computing the mean. This is necessary because df.select(mean('Quantity')) returns a DataFrame with a single row and column. By using first(), you access the first row, and [0] retrieves the value located at the first index of that row.
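Purely for illustration, here is the same calculation with the aggregate column aliased, which lets you access the value by name as well as by position:

from pyspark.sql.functions import mean

row = df.select(mean('Quantity').alias('qty_mean')).first()
# row is a Row object, e.g. Row(qty_mean=...), so both of these work:
qty_mean = row[0]
qty_mean = row['qty_mean']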

Next, I applied the filter() function to identify rows significantly distant from the mean. In this case, 'Quantity' values more than 3 standard deviations below the mean are treated as potential outliers, following the common 3-sigma rule of thumb. However, in real-world scenarios, gaining context from stakeholders is vital in determining true outliers. For this tutorial's scope, I'm treating these instances as outliers.

display(df.filter(df['Quantity'] < (qty_mean - 3 * qty_std)).groupBy('StockCode', 'Description').count())

I applied a similar procedure to the 'UnitPrice' column; together, these filters remove less than 0.5% of the data.

display(df.filter(df['UnitPrice'] > (price_mean + 4 * price_std)).groupBy('Description').count())
new_df = df.filter((df['UnitPrice'] > 0) & (df['UnitPrice'] < (price_mean + 4 * price_std)) & (df['Quantity'] > (qty_mean - 3 * qty_std)))
display(new_df.describe())

Before the removal of outliers, the mean for ‘UnitPrice’ stood at 4.6 with a standard deviation of 97, while ‘Quantity’ showed a mean of 9.6 with a standard deviation of 218. Upon removing the outliers, ‘UnitPrice’ exhibited a mean of 3.6 with a standard deviation of 8.3, and ‘Quantity’ showed a mean of 10.2 with a standard deviation of 154. These significant changes highlight how much outliers affected these statistical measures.

# removed less than 0.5% of data 
100 * (1 - new_df.count()/df.count())

Make sure to verify the extent of outlier removal by comparing the row count of the original DataFrame with that of the filtered DataFrame.
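A quick sanity check along these lines (just an illustrative variant of the cell above) could look like this:

original_rows = df.count()
remaining_rows = new_df.count()
# rows dropped by the outlier filters, as an absolute count and a percentage
print(f"removed {original_rows - remaining_rows:,} of {original_rows:,} rows "
      f"({100 * (1 - remaining_rows / original_rows):.2f}%)")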

3. Count Missing Values

When performing data analysis, we need to check how many missing values the data contains, for a variety of reasons. I checked the missing values both before and after removing outliers.

from pyspark.sql.functions import col, sum 
# Count missing values in each column
missing_counts = df.select([sum(col(c).isNull().cast('int')).alias(c) for c in df.columns])

# Show the count of missing values
missing_counts.show()

In this instance, the missing values in the ‘Description’ column were also eliminated during the outlier removal process. This suggests that rows lacking proper descriptions were considered outliers and subsequently removed in the previous step.

Again, you need to import col and sum from pyspark.sql.functions before using them.
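Note that importing sum this way shadows Python's built-in sum. A common alternative, functionally identical to the snippet above, is to import the functions module under an alias; here it is applied to the filtered DataFrame to repeat the check after outlier removal:

import pyspark.sql.functions as F

# same per-column missing-value count, without shadowing the built-in sum()
missing_counts_after = new_df.select(
    [F.sum(F.col(c).isNull().cast('int')).alias(c) for c in new_df.columns]
)
missing_counts_after.show()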

4. What are the popular & unpopular items?

I want to find the top 10 popular items and the bottom 10 items by revenue. First of all, let's see how many unique items this online e-commerce shop sells. To count distinct values in a column in PySpark, you can select() the column, then apply distinct() and count(). This is equivalent to nunique() in Pandas. There are a total of 3,937 unique items.

new_df.select('StockCode').distinct().count()

For this, I created a 'Revenue' column by multiplying 'Quantity' and 'UnitPrice'. You can create a new column using withColumn() in PySpark.

from pyspark.sql.functions import col

# create a new column Revenue = Quantity * UnitPrice
new_df = new_df.withColumn('Revenue', col('Quantity') * col('UnitPrice'))
display(new_df)

Since I'm only interested in the overall revenue of each item, let's groupBy() the data by 'StockCode', the code that identifies each item. Then, withColumnRenamed() is used to rename the aggregated columns.


product_sales_summary = new_df.groupBy('StockCode', 'Description').sum()
# Rename Columns
product_sales_summary = product_sales_summary.withColumnRenamed('sum(Quantity)', 'ItemQuantity')\
.withColumnRenamed('sum(Revenue)','ItemRevenue')
# Select necessary columns only
product_sales_summary = product_sales_summary.select('StockCode','Description','ItemQuantity','ItemRevenue')
display(product_sales_summary)

Each row now represents the total quantity sold and the total revenue generated by each item within this dataset. To surface the top 10 performing items, the DataFrame is sorted in descending order by 'ItemRevenue' using orderBy() combined with desc().

from pyspark.sql.functions import desc

product_sales_summary = new_df.groupBy('StockCode', 'Description').sum()
# Rename Columns
product_sales_summary = product_sales_summary.withColumnRenamed('sum(Quantity)', 'ItemQuantity')\
.withColumnRenamed('sum(Revenue)','ItemRevenue')
# Select necessary columns only
product_sales_summary = product_sales_summary.select('StockCode','Description','ItemQuantity','ItemRevenue').orderBy(desc('ItemRevenue'))
display(product_sales_summary)

However, interpreting the relative performance of these items solely from ‘ItemRevenue’ — an absolute sales value — poses a challenge. To address this, let’s introduce a new column that represents the revenue percentage. This involves a two-step process: first, calculate the total revenue across all items, and then divide the ‘ItemRevenue’ column by this aggregate value.

from pyspark.sql.functions import col, sum, round

# Calculate the total revenue first
total_revenue = product_sales_summary.select(sum('ItemRevenue')).first()[0]

product_sales_summary = product_sales_summary.withColumn('PctItemRevenue', round((col('ItemRevenue')/total_revenue) * 100,4))
display(product_sales_summary.orderBy(desc('ItemRevenue')))

The top 3 items collectively account for just around 5% of the total revenue, while all other top-performing items individually contribute less than 1% to the overall revenue.
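If you want to see how quickly revenue concentrates across items, one option (a sketch that goes slightly beyond the original analysis) is a running total of 'PctItemRevenue' over a window ordered by item revenue:

from pyspark.sql import Window
from pyspark.sql.functions import sum as sum_, round as round_, desc

# cumulative revenue share, highest-earning items first
w = (Window.orderBy(desc('ItemRevenue'))
           .rowsBetween(Window.unboundedPreceding, Window.currentRow))

display(
    product_sales_summary.withColumn(
        'CumPctRevenue', round_(sum_('PctItemRevenue').over(w), 4)
    )
)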

Now, let's take a look at the less popular items.

# By default, orderBy is sorted in ascending order
display(product_sales_summary.orderBy('ItemRevenue'))

After reviewing these entries, we’ve identified additional outliers for removal; the initial five items appear to include charges, discounts, and other non-standard entries. Pending confirmation with stakeholders, we can exclude these items from subsequent analyses. The following code removes the specific rows based on the ‘Description’ column values.

# remove if 'Description' is Manual, Discount, Samples, Charges or Commission
new_df = new_df.filter(~col('Description').isin('Manual','Discount','SAMPLES','Bank Charges','CRUK Commission'))
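To confirm the filter behaved as expected, you can count how many of those descriptions remain (this check is my addition; it should return 0):

from pyspark.sql.functions import col

# verify the non-product entries are gone
non_products = ['Manual', 'Discount', 'SAMPLES', 'Bank Charges', 'CRUK Commission']
new_df.filter(col('Description').isin(*non_products)).count()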

5. Summary

  • display() or show() : to look at the data
  • printSchema() : to look at the data type
  • describe() : to look at the summary statistics of data
  • df.select(mean('col_name')) : to calculate the mean of a column
  • df.select(stddev('col_name')) : to calculate the standard deviation of a column
  • df.filter(condition) : to filter the data set with a condition
  • df.select('col_name').distinct().count() : to count unique values in a column
  • df.withColumn() : to create a new column
  • df.withColumnRenamed() : to rename a column name
  • df.groupBy() : to group the dataset; used with an aggregate function
  • df.select() : to select specific columns

6. Reference

— Data Source: https://www.kaggle.com/datasets/ulrikthygepedersen/online-retail-dataset
