Spark Optimisation Techniques

Praveen Raj
inspiringbrilliance
7 min read · Feb 11, 2020

Apache Spark is one of the most popular cluster computing frameworks for big data processing. However, writing complex Spark jobs that execute efficiently requires a good understanding of how Spark works and of the various ways to optimise jobs for better performance characteristics, depending on the data distribution and workload. The following techniques will help you tune your Spark jobs for efficiency in CPU, network bandwidth, and memory.

Some of the common techniques you can use to tune your Spark jobs for better performance:
1) Persist/Unpersist
2) Shuffle Partitions
3) Predicate Pushdown
4) Broadcast Joins

Persist/Unpersist

Persist is one of the more interesting abilities of Spark: it stores a computed intermediate RDD or DataFrame across the cluster for much faster access the next time you query it. Spark also supports persisting to disk and replicating across multiple nodes.
Knowing this simple concept can save hours of redundant computation.
Assume I have an initial dataset of 1 TB and I am doing some filtering and other operations over it. The filtered dataset does not yet contain any computed data: Spark is lazy, so it does nothing when you declare transformations such as a filter; it simply records the order of operations (the DAG) that needs to be executed when an action is triggered. Now consider the case where this filtered_df is used by several objects to compute different results.

filtered_df = filter_input_data(initial_data)

for obj in list_objects:
    # filtered_df is recomputed from initial_data on every iteration
    compute_df = compute_dataframe(filtered_df, obj)
    percentage_df = calculate_percentage(compute_df)
    export_as_csv(percentage_df)

What happens here is that on every iteration, when the data frame is exported as CSV, all the transformations in the lineage are executed again before the data is written. The result of filtered_df does not change between iterations, but on every iteration the filtering is recomputed, which is time consuming.

In this example, I ran my Spark job with sample data; every export took roughly 1 minute to complete. Now imagine running it over GBs of data: each iteration would recompute filtered_df from scratch, and the job would take several hours to complete.

Now let me run the same code using persist:

filtered_df = filter_input_data(initial_data)

filtered_df.persist()
for obj in list_objects:
    compute_df = compute_dataframe(filtered_df, obj)
    percentage_df = calculate_percentage(compute_df)
    export_as_csv(percentage_df)
filtered_df.unpersist()

Now filtered_df is computed during the first iteration and then persisted in memory. In subsequent iterations, instead of recomputing filtered_df, the precomputed result is fetched from memory, which saves a lot of computation time. In my run, the first iteration took around 2.5 minutes to do the computation and store the data in memory; from then on every iteration took less than 30 seconds, since the computation of filtered_df is skipped and its result is read from memory.

We can use various storage levels to store persisted RDDs in Apache Spark (a short sketch of passing a storage level explicitly follows the list):

  1. MEMORY_ONLY: the RDD is stored as deserialised Java objects in the JVM. If the RDD does not fit in memory, some partitions are simply not cached and are recomputed when needed.
  2. MEMORY_AND_DISK: the RDD is stored as deserialised Java objects in the JVM. If it does not fit in memory, the remaining partitions are stored on disk.
  3. MEMORY_ONLY_SER: the RDD is stored as serialised objects in the JVM.
  4. MEMORY_AND_DISK_SER: the RDD is stored as serialised objects in the JVM, with partitions that don't fit in memory spilled to disk.
  5. DISK_ONLY: the RDD is stored on disk only.
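As a minimal sketch of passing a storage level explicitly (reusing the hypothetical filtered_df from the earlier snippet; note that PySpark always stores data serialised, so MEMORY_ONLY and MEMORY_AND_DISK are the levels you will typically use there):

from pyspark import StorageLevel

# Keep what fits in memory, spill the rest to disk instead of recomputing it
filtered_df.persist(StorageLevel.MEMORY_AND_DISK)
# ... run the iterations that reuse filtered_df ...
filtered_df.unpersist()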

Persist RDDs/DataFrames that are expensive to recalculate. Persisting a very simple RDD/DataFrame does not make much of a difference; the time to write it to and read it back from disk/memory is about the same as recomputing it.

Unpersist

Unpersist removes the stored data from memory and disk. Make sure you unpersist the data at the end of your Spark job.

Shuffle Partitions

Shuffle partitions are the partitions used when shuffling data for joins or aggregations. Whenever we perform operations like groupBy, a shuffle happens: large chunks of data get moved between partitions, either between partitions on the same machine or between different executors.
When working with plain RDDs, you generally don't need to worry about the shuffle partition count. Say an initial RDD sits in 8 partitions and we do a groupBy over it: the partition count typically remains the same after the grouping, as the short sketch below shows.
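A minimal sketch of the RDD case (the exact count can differ if spark.default.parallelism is set explicitly):

rdd = spark.sparkContext.parallelize([("1", 1), ("2", 1)] * 5, 8)
rdd.getNumPartitions()       # 8
grouped = rdd.groupByKey()
grouped.getNumPartitions()   # typically still 8 -- the parent's partition count is reused

But this is not the case with DataFrames: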

>>> df = spark.createDataFrame(
[('1', 'true'),('2', 'false'),
('1', 'true'),('2', 'false'),
('1', 'true'),('2', 'false'),
('1', 'true'),('2', 'false'),
('1', 'true'),('2', 'false'),
])
>>> df.rdd.getNumPartitions()
8
#Now performing a group by Operation
>>> group_df = df.groupBy("_1").count()
>>> group_df.show()
+---+-----+
| _1|count|
+---+-----+
| 1| 5|
| 2| 5|
+---+-----+
>>> group_df.rdd.getNumPartitions()
200

In the above example, the DataFrame started out with 8 partitions, but after the groupBy the partition count shoots up to 200. This is because Spark's default shuffle partition count for DataFrames is 200.
The shuffle partition count can be changed at runtime using the conf method on the Spark session:
sparkSession.conf.set("spark.sql.shuffle.partitions", 100)
or set when submitting the job through spark-submit:
--conf spark.sql.shuffle.partitions=100

Tuning your Spark configuration to the right shuffle partition count is very important. Say I have a very small dataset and decide to do a groupBy with the default shuffle partition count of 200: I end up with far too many tiny partitions and waste Spark resources on scheduling overhead. In the opposite case, with a very large dataset and the same default of 200, the partitions become too few and too large, and I underutilise my Spark resources.
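As a minimal sketch of lowering the count for a small aggregation (assuming Spark 2.x behaviour, before adaptive query execution started coalescing shuffle partitions automatically), reusing the df from the earlier example:

spark.conf.set("spark.sql.shuffle.partitions", 16)
group_df = df.groupBy("_1").count()
group_df.rdd.getNumPartitions()   # 16 instead of the default 200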

Predicate Pushdown

In SQL, when you write a query that has both a join and a where condition, the naive execution order is to join the full datasets first and only then filter on the where condition. If Spark behaved that way on a very large dataset, the join would take hours of computation because it runs over the unfiltered data, after which it would take more time again to apply the where condition.

Predicate pushdown: the name is self-explanatory. A predicate is generally a where condition that returns true or false. During the scan, Spark pushes the predicate conditions down to the data source (a database, or a columnar file format such as Parquet) and filters the data at the source itself. This reduces the amount of data retrieved and improves query performance: the filtering happens where the data lives, unfiltered data is never transferred over the network, and only the filtered data is held in memory.
We can use the explain method to inspect the physical plan of a DataFrame and check whether predicate pushdown is used. Predicates need to be cast to the corresponding data type; if not, they are not pushed down.

>>> df = spark.read.parquet("file1").filter((F.col("date") >= "2019-08-01") & (F.col("date") <= "2019-09-01"))
>>> df.explain()
== Physical Plan ==
*(1) Project [ id#236, date#237, day_of_week#238, hour_of_day#239, format#242, media_owner#243, environment#244, city#245, county#246, country#247, district#248, postcode_area#249, postcode_district#250, postcode_sector#251, tvregion#252]
+- *(1) Filter ((isnotnull(date#237) && (cast(date#237 as string) >= 2019-08-01)) && (cast(date#237 as string) <= 2019-09-01))
+- *(1) FileScan parquet [id#236,date#237,day_of_week#238,hour_of_day#239,format#242,media_owner#243,environment#244,city#245,county#246,country#247,district#248,postcode_area#249,postcode_district#250,postcode_sector#251,tvregion#252] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/praveenr/workspace/ada-data-pipeline-jobs/campaign-measurement/src/..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:string,date:date,day_of_week:int,hour_of_day:int,impressions:bigint,a..

In the above example, I am trying to filter a dataset to a time frame. PushedFilters lists the predicates that are pushed down to the data source; here, because the date column is compared against plain strings rather than properly cast dates, the greater-than and less-than predicates are not pushed down (PushedFilters is empty).

>>> df = spark.read.parquet("file1").filter((F.col("date") >= datetime.strptime("2019-08-01", "%Y-%m-%d").date()) & (F.col("date") <= datetime.strptime("2019-09-01", "%Y-%m-%d").date()))
>>> df.explain()
== Physical Plan ==
*(1) Project [ id#272, date#273, day_of_week#274, hour_of_day#275, format#278, media_owner#279, environment#280, city#281, county#282, country#283, district#284, postcode_area#285, postcode_district#286, postcode_sector#287, tvregion#288]
+- *(1) Filter ((isnotnull(date#273) && (date#273 >= 18109)) && (date#273 <= 18140))
+- *(1) FileScan parquet [id#272,date#273,day_of_week#274,hour_of_day#275,format#278,media_owner#279,environment#280,city#281,county#282,country#283,district#284,postcode_area#285,postcode_district#286,postcode_sector#287,tvregion#288] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/praveenr/workspace/ada-data-pipeline-jobs/campaign-measurement/src/..., PartitionFilters: [], PushedFilters: [GreaterThanOrEqual(date,2019-08-01), LessThanOrEqual(date,2019-09-01)], ReadSchema: struct<frame_id:string,id:string,date:date,day_of_week:int,hour_of_day:int,impressions:bigint,a...

In the above example, the date is properly cast to a date type, and in the explain output you can now see the predicates listed under PushedFilters.

Broadcast Joins

When we join two large datasets, what happens in the background is that huge amounts of data get shuffled between partitions on the same machine and between partitions on different executors.

Broadcast joins are used when we need to join a large dataset with a small one. With a broadcast join, Spark broadcasts the smaller dataset to every node in the cluster; since the data to be joined is then available on every node, Spark can perform the join without any shuffling. This avoids sending huge amounts of data over the network. Using the explain method we can validate whether the data frame is broadcast. The example below illustrates a broadcast join.

>>> df1 = spark.read.parquet("file1")
>>> df2 = spark.read.parquet("file2")
>>> broadcast_join = df1.join(F.broadcast(df2),"id")
>>> broadcast_join.explain()
== Physical Plan ==
*(2) Project [id#458, city#430]
+- *(2) BroadcastHashJoin [id#458], [id#414], Inner, BuildRight
:- *(2) Project [id#421 AS id#458, city#430]
: +- *(2) Filter isnotnull(id#421)
: +- *(2) FileScan parquet [id#421,city#430] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/src/..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:string,city:string>
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]))
+- *(1) Project [id#346 AS id#414]
+- *(1) Filter isnotnull(id#346)
+- *(1) FileScan parquet [id#346] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/src/..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:string>

This post covers some of the basic factors involved in writing efficient Spark jobs. Applying the above techniques will resolve many of the most common Spark performance issues.
