Handling Skewed Data in Apache Spark

Zaid Erikat
8 min read · Apr 30, 2023


Image by Beth Macdonald on unsplash.com

What is Spark?

Spark is a popular distributed computing engine for processing large datasets. Join operations are common in Spark and can be used to combine data from multiple sources. However, skewed data can pose a significant challenge when performing join operations in Spark.

What is Skewed Data?

Skewed data is a situation where one or more keys have a disproportionately large number of values compared to other keys. This can result in a few partitions being significantly larger than others. When performing join operations, Spark partitions data across the cluster and performs operations on the partitions. If some partitions are significantly larger than others, the processing time for those partitions will be longer, leading to slower performance.
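A quick way to check whether a join key is skewed is to count rows per key and look at the heaviest groups. Here is a minimal sketch, assuming a sales DataFrame like the salesDF used in the examples below:

import org.apache.spark.sql.functions._

// count rows per join key and show the heaviest keys;
// a handful of keys dominating the counts is a strong sign of skew
salesDF.groupBy("product_id")
  .count()
  .orderBy(desc("count"))
  .show(20)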

Issues caused by Skewed Data

Skewed data can cause several issues in Spark, particularly when it comes to data processing and analysis. Here are some of the main issues:

  • Imbalanced workload: When data is skewed, some partitions may have significantly more data than others, causing an imbalance in the workload. This can lead to some tasks taking much longer to complete, which can slow down the entire job.
  • Out of memory errors: In some cases, skewed data can cause out of memory errors because a single partition may contain too much data to fit into memory. This can be particularly problematic if the data is being cached in memory for iterative processing.
  • Uneven resource usage: If one or more partitions contain significantly more data than others, they may consume a disproportionate amount of resources (such as CPU or memory), leading to inefficient resource utilization.
  • Slow processing times: Skewed data can cause slower processing times, particularly for operations like joins and aggregations, which require shuffling and data movement between partitions.
  • Job failures: In extreme cases, skewed data can cause job failures, especially if the skewed partitions cause out-of-memory errors or lead to long-running tasks that exceed the maximum allotted time.

To mitigate these issues, it is important to apply techniques such as salting, co-partitioning, and skew join optimization to handle skewed data. Additionally, it may be necessary to tune Spark configuration parameters (such as the number of partitions) or allocate more resources (such as memory or CPU) to handle skewed data more efficiently.
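Since Spark 3.0, Adaptive Query Execution (AQE) ships with a built-in skew join optimization that splits oversized shuffle partitions at join time. Enabling it is often the first thing to try before hand-rolling any of the techniques below; a configuration sketch, assuming an existing SparkSession named spark (the threshold values shown are illustrative):

// enable AQE and its skew-join handling
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
// a shuffle partition is treated as skewed if it is larger than this size...
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
// ...and at least this many times larger than the median partition size
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")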

Techniques for handling Skewed data

To handle skewed data in Spark join operations, there are several techniques that can be used:

Salting

Salting involves appending a random prefix or suffix to the join key of each record so that rows sharing a hot key are spread across several distinct keys, and therefore across several partitions. It is one of the most common techniques used in Spark to handle skewed joins.

Example

Assume we have two tables, sales and inventory, with the following schema:

sales (sale_id: int, product_id: int, sale_amount: double)
inventory (product_id: int, units_sold: int)

We want to join these two tables to compute the total revenue for each product. However, assume that the sales table is skewed, with a few product_ids having significantly more sales than the others. This can lead to uneven partitioning of the data and slow performance during the join operation.
To use the salting technique, we append a random suffix to the join key on the skewed sales side, and replicate each inventory row once per possible suffix so that every salted sales key still finds its match. This spreads the hot keys more evenly across the partitions, improving the performance of the join operation.
Here's an example of how we can implement salting in Spark:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark: SparkSession = SparkSession.builder
  .appName("SaltingExample")
  .getOrCreate()

val salesDF = spark.read.csv("path_to_sales.csv").toDF("sale_id", "product_id", "sale_amount")
val inventoryDF = spark.read.csv("path_to_inventory.csv").toDF("product_id", "units_sold")

val numSaltValues = 10 // number of salt values to spread each skewed key across

// salt the skewed side: append a random suffix so the rows of a hot product_id
// are spread across up to numSaltValues distinct join keys
val saltedSalesDF = salesDF.withColumn(
  "salted_product_id",
  concat(col("product_id"), lit("_"), (rand() * numSaltValues).cast("int"))
)

// replicate the inventory side once per salt value so every salted sales key finds a match
val saltedInventoryDF = inventoryDF
  .withColumn("salt", explode(array((0 until numSaltValues).map(lit): _*)))
  .withColumn("salted_product_id", concat(col("product_id"), lit("_"), col("salt")))
  .drop("product_id", "salt")

// join on the salted key, then aggregate back on the original product_id
val joinResult = saltedSalesDF.join(saltedInventoryDF, Seq("salted_product_id"))
  .groupBy("product_id")
  .agg(sum("sale_amount").as("total_sales"))

joinResult.show()

In this example, we add a random salt suffix to the product_id on the sales side, so the rows of a heavily skewed product are spread across up to ten distinct join keys instead of one. Because the salt is random, the inventory side is replicated once per salt value so that every salted sales key still has a matching inventory row. We then join the two tables on the salted_product_id column, group the result by the original product_id, and compute the total sales for each product.

By using the salting technique, we can handle skewed data in Spark and improve the performance of join operations. However, the number of salt values is a trade-off: more salt values spread the hot keys further but also replicate the smaller table more times, so it may require some tuning to find the optimal configuration.
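To see whether the salt actually spread the hot keys, a quick sanity check (a sketch reusing the saltedSalesDF from above) is to count rows per salted key:

// the heaviest salted keys should now be roughly numSaltValues times smaller
// than the heaviest unsalted product_id
saltedSalesDF.groupBy("salted_product_id")
  .count()
  .orderBy(desc("count"))
  .show(20)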

Bucketing

Bucketing partitions data into a fixed number of buckets based on the value of a key. When both sides of a join are bucketed the same way, matching keys land in the same bucket on both sides, which keeps the partition layout predictable and avoids repeated shuffles of the same data.

Example

We reuse the sales and inventory tables from the previous example.

To use the bucketing technique, we first repartition both tables into the same fixed number of buckets based on the product_id column, so that matching keys are placed in the same bucket on both sides. We can then use the bucketed data to perform the join more efficiently.
Here’s an example of how we can implement bucketing in Spark:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark: SparkSession = SparkSession.builder
  .appName("BucketingExample")
  .getOrCreate()

val salesDF = spark.read.csv("path_to_sales.csv").toDF("sale_id", "product_id", "sale_amount")
val inventoryDF = spark.read.csv("path_to_inventory.csv").toDF("product_id", "units_sold")

val numBuckets = 10 // number of buckets to use for the bucketing

// bucket the sales data based on the product_id column
val bucketedSalesDF = salesDF.repartition(numBuckets, col("product_id"))
// bucket the inventory data based on the product_id column
val bucketedInventoryDF = inventoryDF.repartition(numBuckets, col("product_id"))

// join the bucketed sales and inventory data on the product_id column
val joinResult = bucketedSalesDF.join(bucketedInventoryDF, Seq("product_id"))
  .groupBy("product_id")
  .agg(sum("sale_amount").as("total_sales"))

joinResult.show()

In this example, we use the repartition function to hash-partition both tables into the same number of buckets on the product_id column, so matching keys are already co-located when the join runs. Finally, we group the join result by product_id and compute the total sales for each product. Note that all rows of a single very hot key still land in the same bucket, so bucketing is most effective when combined with salting if one key dominates.
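Spark also supports persistent bucketing: DataFrameWriter.bucketBy writes the tables pre-bucketed so that later joins on the bucket column can skip the shuffle entirely. A minimal sketch, reusing the DataFrames above (the table names are illustrative):

// write both tables bucketed by product_id into the same number of buckets;
// later joins on product_id can then avoid a full shuffle
salesDF.write
  .bucketBy(numBuckets, "product_id")
  .sortBy("product_id")
  .mode("overwrite")
  .saveAsTable("sales_bucketed")

inventoryDF.write
  .bucketBy(numBuckets, "product_id")
  .sortBy("product_id")
  .mode("overwrite")
  .saveAsTable("inventory_bucketed")

val bucketedJoin = spark.table("sales_bucketed")
  .join(spark.table("inventory_bucketed"), Seq("product_id"))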

Broadcast Join

Broadcast join is a technique used when joining a small table with a large table. The small table is sent in full to every executor, so the large table does not need to be shuffled at all. Because no shuffle on the join key takes place, skewed keys in the large table no longer concentrate work on a few partitions.

Example

For this example, assume we have the sales table from before plus a small product table (product_id, product_name).

To use the broadcast join technique, we first identify the smaller table, which in this case is the product table. Since it is small enough to fit into memory on each executor, we broadcast it to all the worker nodes and join it with the large sales table without shuffling the sales data.

Here's an example of how we can implement broadcast join in Spark:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark: SparkSession = SparkSession.builder
  .appName("BroadcastJoinExample")
  .getOrCreate()

val salesDF = spark.read.csv("path_to_sales.csv").toDF("sale_id", "product_id", "sale_amount")
val productDF = spark.read.csv("path_to_product.csv").toDF("product_id", "product_name")

// the product table is the small side of the join
val smallTable = productDF.select("product_id", "product_name")
// mark the small table for broadcast so every executor receives a full copy
val broadcastTable = broadcast(smallTable)

// join the broadcasted table with the sales table on the product_id column;
// the large sales table is never shuffled for the join
val joinResult = salesDF.join(broadcastTable, Seq("product_id"))
  .groupBy("product_id", "product_name")
  .agg(sum("sale_amount").as("total_sales"))

joinResult.show()

In this example, we use the broadcast function to broadcast the small product table to all the worker nodes. We then join the broadcasted table with the sales table on the product_id column, which ensures that each worker node has all the necessary data for the join operation. Finally, we group the join result by product_id and product_name and compute the total sales for each product.
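Spark will also pick a broadcast join automatically when one side is smaller than spark.sql.autoBroadcastJoinThreshold (10 MB by default); raising the threshold is another way to encourage broadcasting, at the cost of more memory on the driver and executors. A configuration sketch (the 100 MB value is illustrative):

// allow tables up to ~100 MB to be broadcast automatically;
// setting the threshold to -1 disables automatic broadcast joins
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 100L * 1024 * 1024)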

Sampling

Sampling is a technique used to select a representative subset of the data for processing. This reduces the amount of data shuffled during join operations, which can lessen the impact of skewed data, but the results are approximate rather than exact.

Example

Take the same dataset from the last example.

To use the sampling technique, we first identify the skewed portion of the sales table, which contains the product_ids with significantly more sales. We then sample a subset of those rows and join the sample with the product table. This reduces the amount of data processed during the join and improves performance.

Here's an example of how we can implement sampling join in Spark:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark: SparkSession = SparkSession.builder
  .appName("SamplingJoinExample")
  .getOrCreate()
import spark.implicits._

val salesDF = spark.read.csv("path_to_sales.csv").toDF("sale_id", "product_id", "sale_amount")
val productDF = spark.read.csv("path_to_product.csv").toDF("product_id", "product_name")

// isolate the skewed portion of the sales table (the hot product_ids)
val salesSkewedPartition = salesDF.filter($"product_id".isin(1, 2, 3))
// sample 10% of the skewed rows without replacement, using a fixed seed
val sampleDF = salesSkewedPartition.sample(withReplacement = false, fraction = 0.1, seed = 42)

// join the sample with the product table
val joinResult = sampleDF.join(productDF, Seq("product_id"))
  .groupBy("product_id", "product_name")
  .agg(sum("sale_amount").as("total_sales"))

joinResult.show()

In this example, we first identify the skewed portion of the sales table by filtering for product_ids 1, 2, and 3. We then take a subset of those rows using the sample function, which randomly selects 10% of the data with a fixed seed of 42. We join the sample with the product table on the product_id column, group the result by product_id and product_name, and compute the total sales for each product.
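Because only 10% of the skewed rows are scanned, the aggregated totals understate the real figures. A small sketch of scaling them back up into an estimate, reusing the joinResult and the 0.1 fraction from above:

// divide by the sampling fraction to approximate the full totals;
// this is an estimate, not an exact result
val estimatedResult = joinResult.withColumn("estimated_total_sales", col("total_sales") / 0.1)
estimatedResult.show()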

Co-partitioning

Co-partitioning involves partitioning both tables with the same partitioning scheme (the same key and the same number of partitions), so that matching rows are already co-located and the join does not need an additional shuffle. This reduces the data movement needed during join operations on skewed data.

Example:

Assume we have two tables, A and B, with the following schema:
Table A:

+---+-------+
|id |value_A|
+---+-------+
|  1|     10|
|  2|     20|
|  3|     30|
|  4|     40|
|  5|     50|
+---+-------+

Table B:

+---+-------+
|id |value_B|
+---+-------+
|  1|    100|
|  2|    200|
|  3|    300|
|  4|    400|
|  5|    500|
+---+-------+

We want to join these two tables based on the id column. However, assume that the id column in Table A is skewed, with the values 1 and 2 having a significantly larger number of records than the other values. This can lead to uneven partitioning of the data and slow performance during the join operation.

To handle this skew, we can use co-partitioning to partition both tables using the same partitioning scheme. In this example, we will partition the data using the hash partitioning scheme based on the id column:

import spark.implicits._

val numPartitions = 4
val tableA = spark.read.csv("path_to_table_A.csv").toDF("id", "value_A")
val tableB = spark.read.csv("path_to_table_B.csv").toDF("id", "value_B")

// repartition both tables on the same key into the same number of partitions
// (repartitioning by a column expression hash-partitions the data)
val partitionedTableA = tableA.repartition(numPartitions, $"id")
val partitionedTableB = tableB.repartition(numPartitions, $"id")
// join the co-partitioned tables on the id column
val joinedTable = partitionedTableA.join(partitionedTableB, Seq("id"), "inner")

In this example, we partition both tables using the same partitioning scheme based on the id column. We specify the number of partitions as numPartitions; repartitioning by a column expression hash-partitions the data, so both tables end up with an identical layout.

Once both tables are partitioned, we can perform the join operation using the join method in Spark. Because both tables share the same partition layout, matching rows are already co-located, so the join can proceed without an additional shuffle of the repartitioned data.
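Whether the planner actually reuses the repartitioning, rather than shuffling again, is worth verifying; a quick check on the joinedTable from above:

// look for Exchange nodes in the physical plan: if the existing partitioning
// is reused, there should be no extra shuffle before the join
joinedTable.explain()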

Conclusion

In conclusion, skewed data can pose a significant challenge when performing join operations in Spark. However, several techniques can be used to handle it, including salting, bucketing, broadcast joins, sampling, co-partitioning, and Spark's built-in skew join optimization. By applying these techniques, it is possible to improve the performance of join operations in Spark and keep the work more evenly distributed across the partitions.
