Oversampling and Undersampling with PySpark

Jun Wan
2 min read · Feb 9, 2020


In machine learning, when dealing with a classification problem on an imbalanced training dataset, oversampling and undersampling are two easy and often effective ways to improve the outcome.

What Is an Imbalanced Dataset?

An imbalanced dataset is one where the number of examples in one class is significantly greater than in the other. This happens in many areas: a fraud detection dataset has far more normal transactions than fraudulent ones, and in medical diagnostics normal examples outnumber diseased ones.

Such a dataset can cause your model to blindly predict the dominant class, since it can achieve good accuracy that way anyway. Ways to combat this include oversampling the minority class, undersampling the majority class, adding class weights, changing the algorithm, generating synthetic samples, etc.
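
As a quick illustration of the class-weight option, here is a minimal sketch. The toy rows and the weight value 3.0 are made up for illustration; estimators such as pyspark.ml.classification.LogisticRegression can consume the extra column through their weightCol parameter.

# sketch: weight the minority class more heavily instead of resampling
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when

spark = SparkSession.builder.getOrCreate()
toy_df = spark.createDataFrame(
    [['a', 1], ['b', 1], ['c', 1], ['x', 0]], ['feature', 'label'])

# minority rows (label 0) get weight 3.0, majority rows get 1.0;
# pass weightCol="weight" to a supporting estimator when fitting
weighted_df = toy_df.withColumn(
    "weight", when(col("label") == 0, 3.0).otherwise(1.0))
weighted_df.show()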

This article shows how to oversample and undersample with PySpark DataFrames.

PySpark DataFrame Example

Let’s set up a simple PySpark example:

# code block 1
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode, array, lit

spark = SparkSession.builder.getOrCreate()  # reuses the existing session in a shell or notebook

# toy dataset: six rows of class 1 and two rows of class 0
df = spark.createDataFrame(
    [['a', 1], ['b', 1], ['c', 1], ['d', 1], ['e', 1], ['f', 1], ['x', 0], ['y', 0]],
    ['feature', 'label'])
df.show()

# split the two classes and compute the imbalance ratio
major_df = df.filter(col("label") == 1)
minor_df = df.filter(col("label") == 0)
ratio = int(major_df.count() / minor_df.count())
print("ratio: {}".format(ratio))

The output:

+-------+-----+
|feature|label|
+-------+-----+
|      a|    1|
|      b|    1|
|      c|    1|
|      d|    1|
|      e|    1|
|      f|    1|
|      x|    0|
|      y|    0|
+-------+-----+

ratio: 3

Class (label) 1 has 6 examples, while class 0 has only 2. We can undersample class 1 or oversample class 0; the ratio between them is 3.

Oversampling

The idea of oversampling is to duplicate samples from the under-represented class until it reaches the same size as the dominant class. Here is how to do it with a PySpark DataFrame:

# ... continued from code block 1 ...

# duplicate the minority rows
a = range(ratio)
oversampled_df = minor_df.withColumn(
    "dummy", explode(array([lit(x) for x in a]))).drop('dummy')

# combine both oversampled minority rows and previous majority rows
combined_df = major_df.unionAll(oversampled_df)
combined_df.show()

The output:

+-------+-----+
|feature|label|
+-------+-----+
|      a|    1|
|      b|    1|
|      c|    1|
|      d|    1|
|      e|    1|
|      f|    1|
|      x|    0|
|      x|    0|
|      x|    0|
|      y|    0|
|      y|    0|
|      y|    0|
+-------+-----+

In the code above, we leverage Spark's explode function. First, we create a new dummy column containing a literal array of numbers, with the array size being the multiplier we want to apply to the minority class rows. Then the explode function creates a new row for each element in the array. Last, we drop the dummy column.
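
If you want to see the intermediate result for yourself, keep the dummy column instead of dropping it. This is a small check continuing from code block 1; the row order you see may differ.

# ... continued from code block 1 ...
# inspect the intermediate DataFrame before the dummy column is dropped
minor_df.withColumn("dummy", explode(array([lit(x) for x in range(ratio)]))).show()
# each minority row now appears once per array element (dummy = 0, 1, 2),
# so dropping "dummy" leaves three copies of every minority row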

Undersampling

Undersampling is the opposite of oversampling: instead of making duplicates of the minority class, it cuts down the size of the majority class. There is a built-in sample function in PySpark to do that:

# ... continued from code block 1 ...

# randomly keep roughly 1/ratio of the majority rows, sampling without replacement
sampled_majority_df = major_df.sample(False, 1/ratio)
combined_df_2 = sampled_majority_df.unionAll(minor_df)
combined_df_2.show()

The output:

+-------+-----+
|feature|label|
+-------+-----+
|      a|    1|
|      b|    1|
|      x|    0|
|      y|    0|
+-------+-----+
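
Note that sample draws rows at random and the fraction is only approximate, so the rows you get (and even their count) will vary from run to run unless you pass a seed. If you prefer a single, class-aware call, DataFrame.sampleBy is an alternative; the sketch below continues from code block 1, and the fractions and seed are assumptions for illustration rather than part of the original example.

# ... continued from code block 1 ...
# keep roughly 1/ratio of class 1 and all of class 0, seeded for reproducibility
combined_df_3 = df.sampleBy("label", fractions={1: 1/ratio, 0: 1.0}, seed=42)
combined_df_3.show()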

Undersampling reduces your overall training dataset size, so you should try this approach when your original dataset is fairly large.

