Exploring Pyspark.ml for Machine Learning: Exploring Dataset

Sze Zhong LIM · Published in Data And Beyond · Dec 15, 2023 · 8 min read

This article is more about using PySpark to explore a dataset than about PySpark.ml; I am just grouping it into the same series for easier reference.


Exploring a dataset is like exploring a jungle. You do not know what is lying in wait. You can only use your prior experience to navigate and look for points that shed some light on the key insights.

If you haven't installed PySpark on your local machine, here is a link on how to do it. Below is also a gist on how to create your own PySpark dataframe and how to use df.printSchema() and df.show().
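As a stand-in for that gist, here is a minimal sketch of creating a toy PySpark dataframe and inspecting it; the column names and values are made up purely for illustration.

from pyspark.sql import SparkSession

# Start (or reuse) a local Spark session.
spark = SparkSession.builder.appName("eda-demo").getOrCreate()

# A small illustrative dataframe; the column names are arbitrary.
data = [
    ("2023-10-01", "A", 100.0, 1),
    ("2023-10-02", "B", None, 0),
    ("2023-10-03", "A", 250.5, 1),
]
df = spark.createDataFrame(data, ["date_col", "category", "balance", "target"])

df.printSchema()  # prints each column name and its inferred datatype
df.show()         # prints the first 20 rows in a tabular layout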

I will focus on:
1) Filtering the dataset down (as datasets are usually huge)
2) Checking for null values
3) Obtaining the quantiles / distinct counts for all columns
4) Target understanding (binary class)

1. Filtering the Dataset

Assume you have a huge dataset partitioned by date, with dates in the format 2023-11-21, i.e. yyyy-mm-dd. It isn't hard to filter with df = df.filter(F.col('date_col')=='2023-11-21'). But what if I wanted the whole of November 2023, or the 3rd of every month?

Personally, I use a date filter like the one below and modify it as needed.

from pyspark.sql import functions as F

def date_filter(sdf, date_col, target_year, target_month, target_day):
    # Assumes the date column is a string in yyyy-mm-dd format.
    # Passing None for any component skips that part of the filter.
    sdf = sdf.filter(
        (F.col(date_col).substr(9, 2) == target_day if target_day is not None else F.col(date_col).isNotNull()) &
        (F.col(date_col).substr(6, 2) == target_month if target_month is not None else F.col(date_col).isNotNull()) &
        (F.col(date_col).substr(1, 4) == target_year if target_year is not None else F.col(date_col).isNotNull())
    )

    countofdates = sdf.select(date_col).distinct().count()
    print(f"Number of rows for subset with date filter: {sdf.count()}")
    print(f"Number of distinct dates for subset with date filter: {countofdates}")
    print("The partition(s) within this subset are:")
    sdf.select(date_col).distinct().orderBy(date_col).show(countofdates, truncate=False)

    return sdf

# Below would return the partition for 3rd Oct 2023.
sdf = date_filter(sdf, 'date_col', '2023', '10', '03')

# Below would return the partitions for the whole of Oct 2023.
sdf = date_filter(sdf, 'date_col', '2023', '10', None)

# Below would return the partitions for the 1st of every month in 2023.
sdf = date_filter(sdf, 'date_col', '2023', None, '01')

# Below would return the partitions for the 1st of every month, for every year in the dataset.
sdf = date_filter(sdf, 'date_col', None, None, '01')

If you have a range of dates you want to filter by, you may also use the code below:

def date_range(sdf, date_col, start_dt, end_dt):
    # Inclusive of start_dt, exclusive of end_dt.
    sdf = sdf.filter(
        (F.col(date_col) >= start_dt) & (F.col(date_col) < end_dt)
    )

    countofdates = sdf.select(date_col).distinct().count()
    print(f"Number of rows for subset with date filter: {sdf.count()}")
    print(f"Number of distinct dates for subset with date filter: {countofdates}")
    print("The partition(s) within this subset are:")
    sdf.select(date_col).distinct().orderBy(date_col).show(countofdates, truncate=False)

    return sdf

# Gets the date range from 2023-10-01 up to and including 2023-10-29 (the end date is exclusive).
sdf = date_range(sdf, 'date_col', '2023-10-01', '2023-10-30')

Sometimes there are too many columns in a dataset and you want to identify all the columns with date types. At the same time, we will change those columns from date type to StringType. There are a few reasons to do this:
1) We can use the functions above to filter through the dates.
2) toPandas() can throw errors when converting PySpark datetimes to pandas datetimes (at least on some legacy systems; I am not sure if this has been fixed).

You may use the code below:

from pyspark.sql.types import StringType

def date_columns(sdf, convert='Yes'):
    # Find all columns whose Spark type is a date type.
    date_col_list = [col_name for col_name, col_type in sdf.dtypes if 'date' in col_type.lower()]

    if convert == 'Yes':
        for col_name in date_col_list:
            sdf = sdf.withColumn(col_name, F.col(col_name).cast(StringType()))
        print(f"List of date type columns converted to string type: {date_col_list}")
    else:
        print(f"List of date type columns (no conversion done): {date_col_list}")

    return sdf


# Returns the list of date type columns but does not convert them.
sdf = date_columns(sdf, 'No')

# Returns the list of date type columns and converts them all to string type.
sdf = date_columns(sdf, 'Yes')

What if I have 3 different datasets and I want to check whether all of them have similar date partitions?
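The code below assumes a get_partition helper that returns the distinct partition dates of each dataset; it is not a PySpark built-in, so here is one possible sketch of it, assuming the partition column is the same string 'date_col' used above.

# Hypothetical helper: returns a single-column dataframe of the distinct
# partition dates of a dataset, which can then be collect()-ed.
def get_partition(sdf, date_col='date_col'):
    return sdf.select(date_col).distinct().orderBy(date_col)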


# We are going to assume you have extracted the partition list from each dataset.
# row[0] pulls the plain date string out of each Row returned by collect().
dataset1_date_list = [row[0] for row in get_partition(dataset1).collect()]
dataset2_date_list = [row[0] for row in get_partition(dataset2).collect()]
dataset3_date_list = [row[0] for row in get_partition(dataset3).collect()]

# Create a filter by year first
def filter_dates_year(date_list, year):
    return [date for date in date_list if date.startswith(year)]

set1 = set(filter_dates_year(dataset1_date_list, '2023'))
set2 = set(filter_dates_year(dataset2_date_list, '2023'))
set3 = set(filter_dates_year(dataset3_date_list, '2023'))

nameofset = ['set1', 'set2', 'set3']
varofset = [set1, set2, set3]

def set_diff(set_variable_list, set_name_list):
    dict_diff = {}
    for x in range(len(set_variable_list)):
        for y in range(len(set_variable_list)):
            diff = set_variable_list[x].difference(set_variable_list[y])
            if len(diff) == 0:
                pass
            else:
                colname = f"{set_name_list[x]}-{set_name_list[y]}"
                dict_diff[colname] = diff
                print(f"{set_name_list[x]} has something that {set_name_list[y]} does not.")

    return dict_diff

print(set_diff(varofset, nameofset))

# Output will look something like below, meaning set1 and set2 both contain
# a partition that set3 is missing:
# {'set1-set3': {'2023-06-23'}, 'set2-set3': {'2023-06-23'}}

2. Check on Null Values

Checking for nulls is very important before running any exploratory processes on the dataset. It is also important after processing a dataset with joins and similar operations.

# Code to check for nulls.
# Feel free to add more conditions as you see fit.
null_sdf = sdf.select([F.count(F.when(F.col(x).contains("None") |
                                      F.col(x).contains("NULL") |
                                      F.col(x).isNull() |
                                      (F.col(x) == ""), x)).alias(x)
                       for x in sdf.columns])

null_sdf.toPandas()

# Output will be a pandas dataframe showing column name and number of nulls for each column.
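The single-row output becomes hard to read when there are hundreds of columns; one optional convenience (not part of the original snippet) is to transpose the pandas result and sort it.

# Transpose so each original column becomes a row, which is easier to scan
# on wide datasets, then sort by the null count.
null_counts = null_sdf.toPandas().T
null_counts.columns = ['null_count']
print(null_counts.sort_values('null_count', ascending=False))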

3. Obtain Quantiles / Distinct Counts for all Columns

For a dataset with millions of rows and hundreds of columns, we can explore it in a systematic way with the code below. The code handles both continuous and categorical columns.

For continuous columns, it computes quantiles over the distinct values (which is not the same as quantiles over the raw rows) along with the number of distinct values. For categorical columns, it reports the number of distinct values as well as the min and max. The reason for looking at distinct-value quantiles is that, under a class imbalance, quantiles over the raw rows are dominated by the majority values and would not tell us much about the rest of the data. It is therefore good to check the distinct values first before diving deeper, like peeling an onion layer by layer.
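To make the distinction concrete, here is a toy numpy illustration (the numbers are made up): the median over the raw rows is dominated by the majority value, while the median over the distinct values still surfaces the rarer ones.

import numpy as np

rows = [0, 0, 0, 0, 0, 0, 0, 0, 5, 100]    # raw column values, class-imbalance style
distinct_vals = sorted(set(rows))           # [0, 5, 100]

print(np.quantile(rows, 0.5))               # 0.0  -> median of rows hides the rare values
print(np.quantile(distinct_vals, 0.5))      # 5.0  -> median of distinct values surfaces them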


Before running the code, it is recommended to run sdf.printSchema() to find out the datatypes in the dataset, as these need to be entered into the code below.

from pyspark.sql.types import StringType, FloatType

# Function to convert string columns to numerical (float) columns.
def convert_str_to_numerical(sdf, column_list=None):
    target_types = [StringType()]
    if column_list is None:
        column_list = sdf.columns
    for col_name in column_list:
        col_type = sdf.schema[col_name].dataType
        if col_type in target_types:
            sdf = sdf.withColumn(col_name, sdf[col_name].cast(FloatType()))
    return sdf

# Function to obtain the row count for a given distinct value in a column.
def distinct_quantity(sdf, col_name, val_name):
    value = sdf.filter(F.col(col_name) == val_name).count()
    return value


import numpy as np
from pyspark.sql.types import IntegerType, DoubleType, FloatType, DecimalType

def eda_wrapper(sdf, column_list=None):
    if column_list is None:
        column_list = sdf.columns

    for column_name in column_list:
        print(column_name)

        a = sdf.select(column_name).distinct().orderBy(column_name).collect()
        distinct_value = len(a)
        print(f"Number of Distinct Values for {column_name} is {distinct_value}")

        # Modify target_types based on the schema.
        # DecimalType(38,10) is an example; add or modify accordingly.
        # target_types represents continuous variables.
        target_types = [IntegerType(), DoubleType(), FloatType(), DecimalType(38, 10)]

        col_type = sdf.schema[column_name].dataType

        # For continuous variables
        if col_type in target_types and distinct_value > 2:
            try:
                listtocheck = a
                null_present = False
                # orderBy() puts nulls at one end, so check the first and last rows.
                if listtocheck[0][0] is None:
                    listtocheck.pop(0)
                    null_present = True
                elif listtocheck[-1][0] is None:
                    listtocheck.pop(-1)
                    null_present = True

                # Quantiles are computed over the DISTINCT values, not the raw rows.
                min_1 = np.quantile(listtocheck, 0, interpolation='nearest')
                max_1 = np.quantile(listtocheck, 1, interpolation='nearest')
                q1 = np.quantile(listtocheck, 0.25, interpolation='nearest')
                q2 = np.quantile(listtocheck, 0.5, interpolation='nearest')
                q3 = np.quantile(listtocheck, 0.75, interpolation='nearest')
                print(f"Min: {min_1}; Max: {max_1}")
                print(f"Q1: {q1}; Q2: {q2}; Q3: {q3}")
                if null_present:
                    print("Null Present")

            except TypeError as E:
                print(f"TypeError {E}")
                print(f"Min: {a[0][0]}; Max: {a[-1][0]}")

        # For categorical variables
        else:
            try:
                listtocheck = a
                if listtocheck[0][0] is None:
                    if len(listtocheck) > 1:
                        print(f"Min: {a[1][0]}; Max: {a[-1][0]}")
                    else:
                        print("Only value is Null")
                    print("Null Present")
                elif listtocheck[-1][0] is None:
                    print(f"Min: {a[0][0]}; Max: {a[-2][0]}")
                    print("Null Present")
                else:
                    print(f"Min: {a[0][0]}; Max: {a[-1][0]}")
            except TypeError as E:
                print(f"TypeError {E}")
                print(f"Min: {a[0][0]}; Max: {a[-1][0]}")

        # For continuous / categorical variables with fewer than 10 distinct values
        if distinct_value < 10:
            newlst = [x[0] for x in a]
            newdict = {}
            for x in newlst:
                newdict[x] = distinct_quantity(sdf, column_name, x)
            print(newdict)

        print()

# Remember to convert strings to values beforehand if necessary
columns_to_convert = ['col_a', 'col_b']
sdf = convert_str_to_numerical(sdf, columns_to_convert)

# To run on all columns
eda_wrapper(sdf, column_list=None)

# To run on specific columns
cols = ['cola', 'colb', 'colc', 'cold']
eda_wrapper(sdf, cols)

For a more detailed look at the distribution of distinct values (for columns that had more than 10 distinct values above), I use the code below:

def distinct_counts(sdf, col_name, distinct_value=None):
    print(col_name)
    a = sdf.select(col_name).distinct().orderBy(col_name).collect()
    if distinct_value is None:
        # Count every distinct value in the column.
        newlst = [x[0] for x in a]
        newdict = {}
        for x in newlst:
            newdict[x] = distinct_quantity(sdf, col_name, x)
        print(newdict)
    else:
        # Count a single distinct value.
        count = distinct_quantity(sdf, col_name, distinct_value)
        print(f"{distinct_value}:{count}")

# To get ALL the distinct values and counts from 'col_a'
distinct_counts(sdf, 'col_a')

# To get the counts from distinct value 'Yes' from 'col_a'
distinct_counts(sdf, 'col_a', 'Yes')

4. Target Understanding (Binary Class)

In the case of binary classification, where the targets are 1 or 0, we can quickly check the effect a column has on the two classes. Under a class imbalance, where we want the model to learn the minority class 1, we want to identify feature values that differ significantly between the classes and so help the model pick up the minority class's characteristics.

# Assume the target column is called "target".
# Assume there is an amount associated with the target called "target_amt".

def col_check(sdf, colname, criteria, targetcolname, targetamtcolname=None):

    # Modify this part if your target labels are not 1 / 0 (e.g. 'A' / 'B').
    sdf1 = sdf.filter(F.col(targetcolname) == 1)
    sdf0 = sdf.filter(F.col(targetcolname) == 0)

    # Count of Target = 1, when column is >= criteria
    # p = percentage.
    count1m = sdf1.filter(F.col(colname) >= criteria).count()
    count1mp = count1m / (sdf1.count()) * 100

    # Count of Target = 0, when column is >= criteria
    # p = percentage.
    count0m = sdf0.filter(F.col(colname) >= criteria).count()
    count0mp = count0m / (sdf0.count()) * 100

    print(f"For {colname} >= {criteria}:")
    print(f"Class 0 Count: {count0m}. Percentage: {count0mp:.2f}%")
    print(f"Class 1 Count: {count1m}. Percentage: {count1mp:.2f}%")

    if targetamtcolname is not None:
        sum1m = sdf1.filter(F.col(colname) >= criteria).agg(F.sum(F.col(targetamtcolname))).collect()[0][0]
        sum1mp = float(sum1m) / float(sdf1.agg(F.sum(F.col(targetamtcolname))).collect()[0][0]) * 100
        print(f"Class 1 Sum of Amt: {sum1m}. Percentage: {sum1mp:.2f}%")

    # Count of Target = 1, when column is < criteria
    # p = percentage.
    count1l = sdf1.filter(F.col(colname) < criteria).count()
    count1lp = count1l / (sdf1.count()) * 100

    # Count of Target = 0, when column is < criteria
    # p = percentage.
    count0l = sdf0.filter(F.col(colname) < criteria).count()
    count0lp = count0l / (sdf0.count()) * 100

    print(f"For {colname} < {criteria}:")
    print(f"Class 0 Count: {count0l}. Percentage: {count0lp:.2f}%")
    print(f"Class 1 Count: {count1l}. Percentage: {count1lp:.2f}%")

    if targetamtcolname is not None:
        sum1l = sdf1.filter(F.col(colname) < criteria).agg(F.sum(F.col(targetamtcolname))).collect()[0][0]
        sum1lp = float(sum1l) / float(sdf1.agg(F.sum(F.col(targetamtcolname))).collect()[0][0]) * 100
        print(f"Class 1 Sum of Amt: {sum1l}. Percentage: {sum1lp:.2f}%")


# If there is no targetamtcolname
col_check(sdf, 'balance', 500, 'target', targetamtcolname=None)

# If there is a targetamtcolname
col_check(sdf, 'balance', 500, 'target', 'target_amt')

In a class imbalance, when trying to give the model a bias toward learning the minority class, it is important to find a column and criteria where Class 0 and Class 1 differ significantly.

For example, if 'balance' >= 500 covers 20% of Class 0 and 20% of Class 1, we know the cutoff is evenly distributed between the classes and tells us little. However, if 'balance' >= 500 covers 20% of Class 0 but 99% of Class 1, we know that balance >= 500 might help us naturally reduce the class imbalance. All of these theories still have to make business sense. It could also be that the dataset is problematic, and this EDA is how we find out.
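If you want to scan several candidate cutoffs instead of a single one, a simple loop over col_check does the job; the thresholds below are arbitrary examples.

# Arbitrary example cutoffs; replace with values that make business sense.
for cutoff in [100, 500, 1000, 5000]:
    col_check(sdf, 'balance', cutoff, 'target', 'target_amt')
    print('-' * 40)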

WRAP UP

After doing basic EDA on the dataset, you can aggregate the findings with groupBy and bring the results into pandas for deeper analysis. The pandas / Matplotlib / Seaborn combination is a strong way to understand the data visually.
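As a minimal sketch of that hand-off (the 'category' column is a placeholder name), you could aggregate in PySpark, pull only the small summary into pandas, and plot it with Seaborn:

import seaborn as sns
import matplotlib.pyplot as plt

# Aggregate in Spark so only the small summary table reaches pandas.
summary_pdf = (
    sdf.groupBy('category', 'target')   # placeholder column names
       .count()
       .toPandas()
)

# One bar per category, split by target class.
sns.barplot(data=summary_pdf, x='category', y='count', hue='target')
plt.title('Row counts per category, split by target class')
plt.show()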
