Caching Spark DataFrame — How & When

Nofar Mishraki
Published in Pecan Tech Blog
Sep 26, 2020

Let’s begin with the most important point — using the caching feature in Spark is super important. If we don’t use caching in the right places (or maybe don’t use it at all), we can cause severe performance issues.

Sometimes, even if we really understand a concept, it’s challenging to be aware of it with every line of code we write. This is especially true when we are talking about code that is executed in lazy evaluation mode.

It’s vital for us to “develop” the instinct of caching in the right places.
In this article, I’ll share some tips about how to develop this instinct.

How to cache?

Python API

df.cache()

Or:

df.persist()

Unlike cache(), persist() lets us specify the storage level as a parameter.

Storage level refers to the destination location of the cached DataFrame. When using persist(), we can choose between the different options available: MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, etc.

The default storage level for both cache() and persist() on a DataFrame is MEMORY_AND_DISK (Spark 2.4.5): partitions are kept in memory when they fit, and any partitions that don't fit are stored on disk.
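For example, to keep the cached data in memory only (a minimal sketch; MEMORY_ONLY is just one of the levels listed above):

from pyspark import StorageLevel

# Cache the DataFrame in memory only; partitions that don't fit in memory
# will be recomputed on demand instead of spilling to disk.
df.persist(StorageLevel.MEMORY_ONLY)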

Spark SQL

df = spark.read.parquet("s3://.....")
df.createOrReplaceTempView("df")
spark.sql("CACHE TABLE df")

Caching through the Python API is lazy, while in Spark SQL it is eager: with the Python API, the data is only cached once an “action” is executed on the DataFrame, whereas CACHE TABLE materializes the cache immediately.
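Here is a small sketch of the difference (using the same placeholder path as the examples below):

df = spark.read.parquet("s3://.....")
df.cache()    # lazy: nothing is read or cached yet
df.count()    # first action: the data is scanned and cached here

df.createOrReplaceTempView("df")
spark.sql("CACHE TABLE df")    # eager: the table is scanned and cached immediately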

When to cache?

If you’re executing multiple actions on the same DataFrame, then cache it.

Let’s look at an example:

df = spark.read.parquet("s3://.....")
df.createOrReplaceTempView("df")
df = spark.sql("SELECT id, avg(time), count(transaction_id) FROM df GROUP BY id")

for column in df.columns:
    n_unique_values = df.select(column).distinct().count()
    if n_unique_values == 1:
        print(column)

Every time the following line is executed (in this case 3 times, once per column), Spark reads the Parquet file and executes the query.

n_unique_values = df.select(column).distinct().count()

In the following example, I only added the caching:

df = spark.read.parquet("s3://.....")
df.createOrReplaceTempView("df")
df = spark.sql("SELECT id, avg(time), count(transaction_id) FROM df GROUP BY id")
df.persist()  # <---------- THE ONLY CHANGE

for column in df.columns:
    n_unique_values = df.select(column).distinct().count()
    if n_unique_values == 1:
        print(column)

Now Spark reads the Parquet file and executes the query only once, on the first action in the loop, and caches the result. The remaining iterations run against the cached, pre-computed DataFrame.
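As an optional follow-up (my addition, not part of the original example), the cached data can be released once the loop is done, so the memory is freed for the rest of the job:

# Release the cached DataFrame when it is no longer needed.
df.unpersist()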

Imagine that you are working with a lot of data, and you run a series of queries and actions on it without using caching. The same computation runs again and again without you even noticing. This can add hours to the job running time or even make the job fail.

The best way to make sure everything has run as expected is to look at the execution plan.
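From code, the DataFrame explain() method prints the physical plan; whether the plans below were captured with explain() or from the SQL tab of the Spark UI is my assumption:

# Print the physical plan of the distinct query to stdout without running it.
# count() is an action that returns a number, so explain() is called on the
# DataFrame just before the count.
df.select("id").distinct().explain()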

You can see in the following execution plan the keywords InMemoryTableScan and InMemoryRelation, which indicate that we are working on a cached DataFrame.

== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[finalmerge_count(merge count#657L) AS count(1)#638L], output=[count#639L])
+- Exchange SinglePartition, [id=#1695]
   +- *(1) HashAggregate(keys=[], functions=[partial_count(1) AS count#657L], output=[count#657L])
      +- *(1) HashAggregate(keys=[id#593], functions=[], output=[])
         +- InMemoryTableScan [id#593]
            +- InMemoryRelation [id#593, avg(time)#604, count(transaction_id)#605L], StorageLevel(disk, memory, 1 replicas)
               +- *(2) HashAggregate(keys=[id#593], functions=[finalmerge_avg(merge sum#611, count#612L) AS avg(time#594L)#602, finalmerge_count(merge count#614L) AS count(transaction_id#595L)#603L], output=[id#593, avg(time)#604, count(transaction_id)#605L])
                  +- Exchange hashpartitioning(id#593, 200), [id=#1618]
                     +- *(1) HashAggregate(keys=[id#593], functions=[partial_avg(time#594L) AS (sum#611, count#612L), partial_count(transaction_id#595L) AS count#614L], output=[id#593, sum#611, count#612L, count#614L])
                        +- *(1) FileScan parquet [id#593,time#594L,transaction_id#595L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[s3://my_bucket/example], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:string,time:bigint,transaction_id:bigint>

Here is the execution plan without caching:

== Physical Plan ==
*(4) HashAggregate(keys=[], functions=[finalmerge_count(merge count#785L) AS count(1)#781L], output=[count#782L])
+- Exchange SinglePartition, [id=#2030]
   +- *(3) HashAggregate(keys=[], functions=[partial_count(1) AS count#785L], output=[count#785L])
      +- *(3) HashAggregate(keys=[avg(time)#764], functions=[], output=[])
         +- Exchange hashpartitioning(avg(time)#764, 200), [id=#2025]
            +- *(2) HashAggregate(keys=[avg(time)#764], functions=[], output=[avg(time)#764])
               +- *(2) HashAggregate(keys=[id#753], functions=[finalmerge_avg(merge sum#788, count#789L) AS avg(time#754L)#762], output=[avg(time)#764])
                  +- Exchange hashpartitioning(id#753, 200), [id=#2020]
                     +- *(1) HashAggregate(keys=[id#753], functions=[partial_avg(time#754L) AS (sum#788, count#789L)], output=[id#753, sum#788, count#789L])
                        +- *(1) FileScan parquet [id#753,time#754L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[s3://my_bucket/example], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:string,time:bigint>

In addition, you can see the cached table in Storage view:

Storage View — Spark UI
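A quick programmatic check is also possible (this snippet is my addition, not from the Spark UI):

# storageLevel reports how (and whether) the DataFrame is persisted.
print(df.storageLevel)

# isCached reports whether a named temp view is cached
# (True after the Spark SQL "CACHE TABLE df" above).
print(spark.catalog.isCached("df"))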

Conclusion

I decided to write about Spark caching because it’s a fundamental concept that is important to remember when we write Spark jobs.

Caching is not always intuitive; sometimes we end up fighting performance issues caused by unnecessary recomputations we are not even aware of.

My goal is to share my thought process and some tools I use to deal with Spark job optimization in the context of caching.

I hope you’ll find this article helpful.
