In the realm of burgeoning data sizes, data scientists are turning to PySpark over Pandas for its scalability. Join me today as I delve into the fundamentals of Exploratory Data Analysis (EDA) using PySpark, specifically focusing on Online Retail Sales Data. This walkthrough will take place within the dynamic environment of Databricks.
1. What does the overall data look like?
There are multiple ways to look at the data: display() or show().
display(df)
df.show(10) # you can choose how many rows you want to show
If you want to look at the overall schema and data types, use printSchema().
df.printSchema()
What if you want to understand the overall summary of the DataFrame? PySpark offers the describe() function, which is equivalent to the describe() function in Pandas.
df.describe()
describe() provides summary statistics, such as mean, stddev, min, and max, for each numeric column. Just by looking at this summary, we can observe that there are outliers in Quantity and UnitPrice. Let's take a look at them!
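To make these statistics concrete, here is a minimal pure-Python sketch (using made-up sample values, not the actual dataset) of the numbers describe() reports for a numeric column. Note that Spark's stddev() is the sample standard deviation, which statistics.stdev() also computes:

```python
import statistics

# Hypothetical sample of a numeric column such as Quantity
quantities = [2, 6, 8, 12, 480]  # the 480 hints at an outlier

print("count :", len(quantities))
print("mean  :", statistics.mean(quantities))    # 101.6
print("stddev:", statistics.stdev(quantities))   # sample stddev, like Spark's stddev()
print("min   :", min(quantities))                # 2
print("max   :", max(quantities))                # 480
```

A single extreme value like 480 inflates both the mean and the standard deviation, which is exactly what the describe() output hinted at for Quantity and UnitPrice.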
2. Check & Remove Outliers
Outlier handling is another topic by itself. In this section, I’ll illustrate a basic approach to address outliers by removing values that significantly deviate from the mean and standard deviation of a column.
Let's first calculate the means and standard deviations of Quantity and UnitPrice. Note that you have to import mean and stddev from pyspark.sql.functions to perform these calculations.
from pyspark.sql.functions import mean, stddev
qty_mean = df.select(mean('Quantity')).first()[0]
qty_std = df.select(stddev('Quantity')).first()[0]
price_mean = df.select(mean('UnitPrice')).first()[0]
price_std = df.select(stddev('UnitPrice')).first()[0]
It's important to retrieve the calculated value using first()[0] after computing the mean. This is necessary because df.select(mean('Quantity')) returns a DataFrame with a single row and a single column. Using first(), you access the first row, and [0] retrieves the value located at the first index of that row.
Next, I applied the filter() function to identify rows significantly distant from the mean. In this case, 'Quantity' values more than 3 standard deviations away from the mean are considered potential outliers, following a common statistical rule of thumb. However, in real-world scenarios, gaining context from stakeholders is vital in determining true outliers. For this tutorial's scope, I'm treating these instances as outliers.
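The 3-standard-deviation rule itself is independent of Spark. A minimal pure-Python sketch of the same cutoff logic, on made-up numbers, looks like this:

```python
import statistics

# Hypothetical Quantity values: mostly small, with one extreme entry
quantities = [5] * 20 + [10] * 20 + [10000]

m = statistics.mean(quantities)
s = statistics.stdev(quantities)

# Keep only values within 3 standard deviations of the mean
kept = [q for q in quantities if abs(q - m) <= 3 * s]
print(len(kept))  # 40 — the extreme value 10000 is dropped
```

Note that with very small samples this rule cannot flag anything (one point out of a handful can never be 3 standard deviations out), which is one reason the full dataset, not a preview, should drive the thresholds.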
display(df.filter(df['Quantity'] < - (qty_mean + 3 * qty_std)).groupBy('StockCode','Description').count())
I conducted a similar procedure on the 'UnitPrice' column, subsequently removing about 0.5% of the data based on this criterion.
display(df.filter(df['UnitPrice'] > (price_mean + 4 * price_std)).groupBy('Description').count())
new_df = df.filter((df['UnitPrice'] > 0) & ( df['UnitPrice'] < (price_mean + 4 * price_std)) & (df['Quantity'] > - (qty_mean + 3 * qty_std)))
display(new_df.describe())
Before the removal of outliers, the mean for ‘UnitPrice’ stood at 4.6 with a standard deviation of 97, while ‘Quantity’ showed a mean of 9.6 with a standard deviation of 218. Upon removing the outliers, ‘UnitPrice’ exhibited a mean of 3.6 with a standard deviation of 8.3, and ‘Quantity’ showed a mean of 10.2 with a standard deviation of 154. These significant changes highlight how much outliers affected these statistical measures.
# removed less than 0.5% of data
100 * (1 - new_df.count()/df.count())
Please verify the extent of outlier removal by comparing the row count of the original DataFrame with that of the modified DataFrame.
3. Count Missing Values
When performing data analysis, we have to check how many missing values are in the data, for many reasons. I checked the missing values before removing outliers vs. after removing outliers.
from pyspark.sql.functions import col, sum
# Count missing values in each column
missing_counts = df.select([sum(col(c).isNull().cast('int')).alias(c) for c in df.columns])
# Show the count of missing values
missing_counts.show()
In this instance, the missing values in the ‘Description’ column were also eliminated during the outlier removal process. This suggests that rows lacking proper descriptions were considered outliers and subsequently removed in the previous step.
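The per-column counting trick above (cast each null check to an integer, then sum) can be mirrored in plain Python. This sketch uses a hypothetical list-of-dicts stand-in for the DataFrame:

```python
# Pure-Python mirror of the per-column missing-value count:
# sum a 0/1 null indicator for every column
rows = [
    {"StockCode": "85123A", "Description": "WHITE HEART", "Quantity": 6},
    {"StockCode": "71053",  "Description": None,          "Quantity": 2},
    {"StockCode": "84406B", "Description": None,          "Quantity": None},
]

columns = rows[0].keys()
missing_counts = {c: sum(r[c] is None for r in rows) for c in columns}
print(missing_counts)  # {'StockCode': 0, 'Description': 2, 'Quantity': 1}
```

The PySpark version does the same thing column by column, except the work is distributed across the cluster instead of running in a single Python loop.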
4. What are popular & not-popular items?
I want to find the top 10 popular items and the bottom 10 items by revenue. First of all, let's see how many unique items this online e-commerce shop sells. To count distinct values in a column in PySpark, you can select() the column, then apply distinct() and count(). This is equivalent to nunique() in Pandas. There are a total of 3,937 unique items.
new_df.select('StockCode').distinct().count()
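In plain Python the same distinct count is just the size of a set, which is a handy mental model for what distinct().count() computes (illustrative values only):

```python
# Hypothetical stock codes with duplicates
stock_codes = ["85123A", "71053", "85123A", "84406B", "71053"]

distinct_count = len(set(stock_codes))  # same idea as nunique() in Pandas
print(distinct_count)  # 3
```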
For this, I created a 'Revenue' column by multiplying 'Quantity' and 'UnitPrice'. You can create a new column using withColumn() in PySpark.
from pyspark.sql.functions import col
# create a new column Revenue = Quantity * UnitPrice
new_df = new_df.withColumn('Revenue', col('Quantity') * col('UnitPrice'))
display(new_df)
Since I'm only interested in overall revenue for each item, let's groupBy() the data by 'StockCode', which is a code that represents each item. Then, withColumnRenamed() is used to rename columns.
product_sales_summary = new_df.groupBy('StockCode', 'Description').sum()
# Rename Columns
product_sales_summary = product_sales_summary.withColumnRenamed('sum(Quantity)', 'ItemQuantity')\
.withColumnRenamed('sum(Revenue)','ItemRevenue')
# Select necessary columns only
product_sales_summary = product_sales_summary.select('StockCode','Description','ItemQuantity','ItemRevenue')
display(product_sales_summary)
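What groupBy(...).sum() does here can be sketched in pure Python as accumulating per-key totals (the line items below are made up for illustration):

```python
from collections import defaultdict

# Hypothetical line items: (StockCode, Quantity, Revenue)
lines = [
    ("85123A", 6, 15.30),
    ("85123A", 2,  5.10),
    ("71053",  4, 13.56),
]

# StockCode -> [ItemQuantity, ItemRevenue], like sum(Quantity) and sum(Revenue)
totals = defaultdict(lambda: [0, 0.0])
for code, qty, rev in lines:
    totals[code][0] += qty
    totals[code][1] += rev

print(dict(totals))
```

Each key ends up with one aggregated row, which is exactly the shape of product_sales_summary: one row per StockCode with the summed quantity and revenue.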
Each row now signifies the total quantity of each item sold and the total revenue generated by each item within this dataset. To focus solely on the top 10 performing items, the DataFrame has been sorted in descending order by 'ItemRevenue' using orderBy() combined with desc(). This ordering allows us to identify the top-performing items based on revenue.
from pyspark.sql.functions import desc
product_sales_summary = new_df.groupBy('StockCode', 'Description').sum()
# Rename Columns
product_sales_summary = product_sales_summary.withColumnRenamed('sum(Quantity)', 'ItemQuantity')\
.withColumnRenamed('sum(Revenue)','ItemRevenue')
# Select necessary columns only
product_sales_summary = product_sales_summary.select('StockCode','Description','ItemQuantity','ItemRevenue').orderBy(desc('ItemRevenue'))
display(product_sales_summary)
However, interpreting the relative performance of these items solely from ‘ItemRevenue’ — an absolute sales value — poses a challenge. To address this, let’s introduce a new column that represents the revenue percentage. This involves a two-step process: first, calculate the total revenue across all items, and then divide the ‘ItemRevenue’ column by this aggregate value.
from pyspark.sql.functions import col, sum, round
# Calculate the total revenue first
total_revenue = product_sales_summary.select(sum('ItemRevenue')).first()[0]
product_sales_summary = product_sales_summary.withColumn('PctItemRevenue', round((col('ItemRevenue')/total_revenue) * 100,4))
display(product_sales_summary.orderBy(desc('ItemRevenue')))
The top 3 items collectively account for just around 5% of the total revenue, while all other top-performing items individually contribute less than 1% to the overall revenue.
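The revenue-percentage step is simple arithmetic: divide each item's revenue by the grand total. A pure-Python sketch with hypothetical numbers:

```python
# Hypothetical per-item revenues
item_revenue = {"A": 500.0, "B": 300.0, "C": 200.0}

# Same two-step logic: total first, then each item's share of it
total_revenue = sum(item_revenue.values())
pct_revenue = {k: round(100 * v / total_revenue, 4)
               for k, v in item_revenue.items()}
print(pct_revenue)  # {'A': 50.0, 'B': 30.0, 'C': 20.0}
```

By construction the percentages sum to 100, which is a quick sanity check to run on the real PctItemRevenue column as well.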
Now, let's take a look at the less popular items.
# By default, orderBy is sorted in ascending order
display(product_sales_summary.orderBy('ItemRevenue'))
After reviewing these entries, we’ve identified additional outliers for removal; the initial five items appear to include charges, discounts, and other non-standard entries. Pending confirmation with stakeholders, we can exclude these items from subsequent analyses. The following code removes the specific rows based on the ‘Description’ column values.
# remove if 'Description' is Manual, Discount, Samples, Charges or Commission
new_df = new_df.filter(~col('Description').isin('Manual','Discount','SAMPLES','Bank Charges','CRUK Commission'))
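The ~col(...).isin(...) pattern is a negated membership test. In plain Python it reads as "keep rows whose value is not in the block-list" (descriptions below are made up except for the block-list itself):

```python
# Block-list of non-product entries, as in the PySpark filter above
drop_list = {"Manual", "Discount", "SAMPLES", "Bank Charges", "CRUK Commission"}

descriptions = ["WHITE HEART", "Discount", "RED LANTERN", "Manual"]
kept = [d for d in descriptions if d not in drop_list]
print(kept)  # ['WHITE HEART', 'RED LANTERN']
```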
5. Summary
- display() or show(): to look at the data
- printSchema(): to look at the data types
- describe(): to look at the summary statistics of the data
- df.select(mean('col_name')): to calculate the mean of a column
- df.select(stddev('col_name')): to calculate the standard deviation of a column
- df.filter(condition): to filter the dataset with a condition
- df.select('col_name').distinct().count(): to count unique values in a column
- df.withColumn(): to create a new column
- df.withColumnRenamed(): to rename a column
- df.groupBy(): to group the dataset; used with an aggregate function
- df.select(): to select specific columns
6. Reference
— Data Source: https://www.kaggle.com/datasets/ulrikthygepedersen/online-retail-dataset