Spark Under the Hood: randomSplit() and sample() Inner Workings

Meltem Tutar
Published in Udemy Tech Blog
Apr 30, 2020

I recently joined the Recommendations Team at Udemy as a data scientist and am learning a lot about data platforms and tools. Udemy has millions of users and a diverse range of big data, including course consumptions, enrollments, reviews, page interactions, and more. The goal of our team is to use this data and cutting-edge approaches to provide quality recommendations to our users. My first project was to improve matrix-factorization-based recommendations, and I had to learn PySpark, the Apache Spark Python API for big data processing. After taking a few Udemy courses I was ready to go 😉.

Initially, I must admit I was treating most functions as black boxes and using them as I would use Python functions. Then I started encountering unexpected behavior after using the randomSplit() method, such as duplicate or disappearing rows and unexpected results from joins. This led me to examine how randomSplit() works and what was causing the anomalies. While doing so, I also noticed other unexpected behavior, such as sample() returning different samples despite being applied to the same data frame with the same seed. These behaviors showed me the importance of understanding how things are implemented in Spark. In this article, I summarize my findings. I first discuss the inconsistencies I encountered, then explain the randomSplit() implementation, and finally outline methods to avoid these issues.

Inconsistencies Encountered

Despite using the same seed, the splits from randomSplit() behaved non-deterministically. The inconsistencies did not occur on every run, and subsequent runs behaved differently. I'll use code run in Apache Zeppelin to illustrate this. The sum of the number of data points in my splits differed from the total number of data points I originally had. See Figure 1 for an example where the splits caused varying numbers of data points to disappear, while a subsequent run produced a valid split. Note that Spark evaluates transformations lazily and, unless a data frame is cached, re-evaluates them for every action. In other words, every time count() is called, Spark applies randomSplit() again.

Figure 1: Example where randomSplit() resulted in splits with missing values and a subsequent run resulted in a valid split. Splits display non-deterministic behavior.
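To make the re-evaluation point concrete, here is a toy, pure-Python model of lazy evaluation. It is not Spark code, and `LazyFrame` and `read_source` are hypothetical names. Transformations only record work, and every action replays the whole lineage from the source. The same thing happens to an uncached Spark data frame.

```python
class LazyFrame:
    """Toy stand-in for an uncached, lazily evaluated data frame (hypothetical)."""

    def __init__(self, fetch, transforms=()):
        self._fetch = fetch                # callable that reads the source data
        self._transforms = tuple(transforms)

    def filter(self, predicate):
        # A transformation only records work; nothing is read yet.
        return LazyFrame(self._fetch, self._transforms + (predicate,))

    def count(self):
        # An action replays the full lineage: fetch, then apply every transform.
        rows = self._fetch()
        for predicate in self._transforms:
            rows = [r for r in rows if predicate(r)]
        return len(rows)


fetches = []


def read_source():
    fetches.append(1)                      # record every trip to the source
    return list(range(10))


df = LazyFrame(read_source)
evens = df.filter(lambda x: x % 2 == 0)
print(evens.count(), evens.count(), len(fetches))  # 5 5 2
```

Each call to `count()` triggers a fresh read. If the source can return its rows differently from one read to the next, each action can therefore see different data.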

After researching the implementation of randomSplit(), I found that it relies on sorting within partitions and sampling, so I investigated the sample() method as well. Looking at Figure 2, you can see that sampling, despite using the same seed, may result in different samples.

Figure 2: Example of inconsistencies with sample()
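A minimal sketch of why the same seed can yield different samples. It assumes a simple Bernoulli sampler that consumes one uniform draw per row, in the order the rows arrive. Python's `random.Random` stands in for Spark's internal random number generator, and `bernoulli_sample` and `kept_positions` are hypothetical helpers. The seed fixes which *positions* are kept, not which rows, so rows that arrive in a different order can produce a different sample.

```python
import random


def bernoulli_sample(rows, fraction, seed):
    """Keep each row independently with probability `fraction`.

    One uniform draw is consumed per row, in arrival order.
    """
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fraction]


def kept_positions(n_rows, fraction, seed):
    """Positions 0..n_rows-1 that pass the test for this seed."""
    rng = random.Random(seed)
    return [i for i in range(n_rows) if rng.random() < fraction]


rows = [f"user_{i}" for i in range(20)]
reordered = list(reversed(rows))        # same data, different arrival order

first = bernoulli_sample(rows, 0.5, seed=7)
again = bernoulli_sample(rows, 0.5, seed=7)
other = bernoulli_sample(reordered, 0.5, seed=7)

positions = kept_positions(len(rows), 0.5, seed=7)
print(first == again)                              # True: same seed, same order
print(other == [reordered[i] for i in positions])  # True: same positions, other rows
```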

The non-deterministic results of these methods can have other repercussions. As I mentioned, Spark re-computes data frames behind the scenes according to the transformations and actions you've specified, so a data frame whose contents change when it is re-fetched is a problem. In my case, this led to unexpected results from cross joins. For example, a user existed in a data frame, but after a cross join with another data frame the user's data disappeared. This happened because Spark re-computes the splits for each action. When I first checked, the user was in one of the splits. After another action following the cross join re-computed the splits, the user was no longer in that split and had, in effect, disappeared.
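The disappearing-user effect can be reproduced with a hand-made example. This sketch models a single partition whose row order differs between two fetches. To make the outcome easy to follow, it replaces the seeded random stream with a fixed, hypothetical list of draws (`DRAWS`). The training split is evaluated during one action, and the test split during a later action after a re-fetch.

```python
# Hypothetical stand-in for a seeded random stream: one draw per row, consumed
# in arrival order, identical every time a split is evaluated.
DRAWS = [0.10, 0.90, 0.30, 0.85]


def split(rows, lb, ub):
    """Bernoulli cell sampling: keep each row whose draw falls in [lb, ub)."""
    return {row for row, x in zip(rows, DRAWS) if lb <= x < ub}


fetch_1 = ["a", "b", "c", "d"]   # row order seen by the first action
fetch_2 = ["d", "c", "b", "a"]   # row order after a later re-fetch

first_look = split(fetch_1, 0.8, 1.0)   # user "b" is in the test split here
train = split(fetch_1, 0.0, 0.8)        # evaluated during the first action
test = split(fetch_2, 0.8, 1.0)         # re-evaluated during a later action

print(sorted(first_look), sorted(test))      # ['b', 'd'] ['a', 'c']
print(sorted(train & test))                  # ['a', 'c'] appear in both splits
print(sorted(set(fetch_1) - train - test))   # ['b', 'd'] appear in neither
```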

These anomalies may depend on the data frame's underlying data source; in my case, I was using a SQL query to load the data from a Hive database. However, I also found unanswered posts on StackOverflow and issues reported in Jira from people who had problems with randomSplit() on other data sources. Next, I explain the implementation of randomSplit() to help us understand what was causing these anomalies.

Spark Implementation of randomSplit()

Function Signature

The signature of randomSplit() takes a list of weights and a seed. The weight list specifies the number of splits and the approximate fraction of the data in each; weights are normalized if they don't sum to 1.0. The seed is for reproducibility. The fractions are only approximate because each row is assigned to a split independently at random, so split sizes vary around the requested ratios.

For example, the code in Figure 3 splits df into two data frames: train_df, with roughly 80% of the rows, and test_df, with roughly 20%. Because we use the same value for the random seed, we expect the same data points to land in the same split whether we re-run the script or Spark internally rebuilds the splits.

Figure 3: randomSplit() signature function example
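A call like the one in Figure 3 looks roughly like `train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)`, where the seed value is only illustrative. The following sketch models how the weights become acceptance ranges. It is based on my reading of Spark's Scala `Dataset.randomSplit`, which normalizes the weights and takes their running sum. `split_bounds` is a hypothetical helper, not a Spark API.

```python
from itertools import accumulate


def split_bounds(weights):
    """Turn randomSplit-style weights into [lower, upper) acceptance ranges.

    Weights are normalized to sum to 1 and then accumulated, so consecutive
    ranges share a boundary and together cover [0.0, 1.0).
    """
    if any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError("weights must be non-negative with a positive sum")
    total = sum(weights)
    cumulative = list(accumulate((w / total for w in weights), initial=0.0))
    return list(zip(cumulative, cumulative[1:]))


print(split_bounds([0.8, 0.2]))  # [(0.0, 0.8), (0.8, 1.0)]
print(split_bounds([1, 2, 1]))   # [(0.0, 0.25), (0.25, 0.75), (0.75, 1.0)]
```

Unnormalized weights such as `[1, 2, 1]` behave the same as `[0.25, 0.5, 0.25]`.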

Under the Hood

The following process is repeated to generate each split data frame: partitioning, sorting within partitions, and Bernoulli sampling. If the original data frame is not cached, the data is re-fetched, re-partitioned, and re-sorted for each split. This is the source of potential anomalies: if partition contents or order differ between these re-computations, the splits are no longer drawn from the same data. In effect, randomSplit() performs sample() once per split, with the sampling range changing from split to split. This is evident if you examine the source code for randomSplit() in PySpark³. This blog⁴ also provides more information and visuals on how randomSplit() is implemented.

Let's walk through an example. Figure 4 diagrams the sample() step for each split, starting with the 0.80 split.

Figure 4: Process of generating the 0.8 split. Identical to sample() implementation.

Spark uses Bernoulli sampling. In short, it generates a random number for each item (data point) and accepts the item into a split if the number falls within a range determined by the split ratio. For the 0.8 split, the acceptance range of the Bernoulli cell sampler is [0.0, 0.8).

The same sampling process is followed for the 0.20 split in Figure 5; only the acceptance range changes, to [0.8, 1.0).

Figure 5: Process of generating the 0.2 split. Identical to sample() implementation. Partition contents remain constant and sorting order is preserved, ensuring a valid split.
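The two acceptance ranges can be sketched as a Bernoulli cell sampler that runs twice over the same partition with the same seed. This is a simplified model that uses Python's `random.Random` in place of Spark's generator; `bernoulli_cell_sample` is a hypothetical name. The k-th row receives the same draw in both passes, and [0.0, 0.8) and [0.8, 1.0) do not overlap, so every row lands in exactly one split.

```python
import random


def bernoulli_cell_sample(partition, lb, ub, seed):
    """Keep each row whose uniform draw x satisfies lb <= x < ub.

    A fresh generator with the same seed is used for every split, so the k-th
    row of the partition receives the same draw each time.
    """
    rng = random.Random(seed)
    return [row for row in partition if lb <= rng.random() < ub]


partition = list(range(100))
train = bernoulli_cell_sample(partition, 0.0, 0.8, seed=42)
test = bernoulli_cell_sample(partition, 0.8, 1.0, seed=42)

print(len(train) + len(test))        # 100: every row lands in exactly one split
print(set(train).isdisjoint(test))   # True
```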

The data frame is re-fetched, partitioned, and sorted within partitions again. In this example the RDD partitions are idempotent: the data points in each partition in Figure 4 remain in the same partition in Figure 5. For example, points b and c are in Partition 1 in both figures. Additionally, the seed associated with each partition remains constant, and the order within each partition is identical. All three properties are fundamental to both sample() and randomSplit(). They ensure that the same seed produces the same sample in the former, and that no data points are duplicated or lost across splits in the latter.
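Putting the pieces together, here is a pure-Python sketch of the whole randomSplit() process over several partitions. It assumes, as I understand Spark's RDD-level sampler to work, that each partition's generator is seeded with the user seed plus the partition index. `random_split` and the partition layouts are hypothetical. Sorting within each partition makes the result robust to a change in row order, but not to rows moving between partitions.

```python
import random


def random_split(partitions, weights, seed):
    """Toy randomSplit over a list of partitions (each a list of rows).

    For each split: sort every partition, then keep the rows whose draw from a
    generator seeded with seed + partition_index falls in [lb, ub).
    """
    total = sum(weights)
    bounds, acc = [], 0.0
    for w in weights:
        bounds.append((acc, acc + w / total))
        acc += w / total
    splits = []
    for lb, ub in bounds:
        chosen = []
        for index, part in enumerate(partitions):
            rng = random.Random(seed + index)  # same per-partition seed for every split
            chosen.extend(row for row in sorted(part) if lb <= rng.random() < ub)
        splits.append(chosen)
    return splits


rows = list(range(100))
fetch_a = [rows[:50], rows[50:]]                # partition contents on the first fetch
fetch_b = [list(reversed(p)) for p in fetch_a]  # same contents, different row order

train, _ = random_split(fetch_a, [0.8, 0.2], seed=42)
_, test = random_split(fetch_b, [0.8, 0.2], seed=42)
print(len(train) + len(test), set(train).isdisjoint(test))  # 100 True

# If rows move between partitions on a re-fetch, there is no such guarantee:
fetch_c = [rows[::2], rows[1::2]]
_, test_c = random_split(fetch_c, [0.8, 0.2], seed=42)
# counts may not add up to 100, and rows may appear in both splits
print(len(train) + len(test_c), len(set(train) & set(test_c)))
```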

Solutions To Avoiding Inconsistencies

Fixing these issues comes down to making RDD partition contents and sorting order idempotent, that is, identical every time the data is re-computed. Any one of the following three methods achieves this: 1) caching the data frame before the operations, 2) repartitioning by a column or a set of columns, and 3) using aggregate functions⁵. An example of each method is shown in Figure 6.

Figure 6: Three different methods to avoid inconsistencies in randomSplit() and sample()
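Repartitioning by a column helps because hash partitioning assigns each row to a partition based only on its key, regardless of the order in which rows arrive. In PySpark this is a call such as `df.repartition("user_id")`, where the column name is hypothetical. The sketch models the idea in plain Python, using `zlib.crc32` as a stand-in for the hash function Spark actually uses. `hash_partition` and `partition_keys` are hypothetical helpers.

```python
import random
import zlib


def hash_partition(rows, key, num_partitions):
    """Assign each row to a partition using a deterministic hash of one column."""
    parts = [[] for _ in range(num_partitions)]
    for row in rows:
        h = zlib.crc32(str(row[key]).encode("utf-8"))
        parts[h % num_partitions].append(row)
    return parts


def partition_keys(parts, key):
    # Partition contents as sorted key lists, ignoring arrival order.
    return [sorted(row[key] for row in part) for part in parts]


rows = [{"user_id": i, "course_id": i % 7} for i in range(30)]
reshuffled = rows[:]
random.Random(0).shuffle(reshuffled)  # the same rows, returned in a different order

first = hash_partition(rows, "user_id", 4)
second = hash_partition(reshuffled, "user_id", 4)
print(partition_keys(first, "user_id") == partition_keys(second, "user_id"))  # True
```

Once partition contents are fixed this way, the sort within each partition restores a stable row order, and the per-partition sampling becomes repeatable.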

Caching the original data frame keeps the partition contents in memory. Instead of re-fetching, re-partitioning, and re-sorting the data, Spark continues from the cached partitions. Note that for RDDs, cache() is shorthand for persist(StorageLevel.MEMORY_ONLY). For DataFrames, cache() uses a memory-and-disk default (StorageLevel.MEMORY_AND_DISK in Spark 2.x). If memory is limited, you can pass an explicit level to persist(), such as persist(pyspark.StorageLevel.MEMORY_AND_DISK) or persist(pyspark.StorageLevel.DISK_ONLY). If a cached partition cannot be stored or is evicted, Spark re-fetches and re-partitions that data from scratch, so it may be wise to monitor caching in the Storage tab of the Spark Web UI. Caching is the solution I chose in my case.
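Caching can be modelled the same way: read the source once and run every split over that single snapshot. In PySpark the equivalent step is `df.cache()` or `df.persist(StorageLevel.MEMORY_AND_DISK)` before calling randomSplit(). In this sketch, `Source` is a hypothetical data source that deliberately returns rows in a different order on every read, and `cell_sample` is a hypothetical helper.

```python
import random


class Source:
    """Hypothetical source whose row order changes on every read."""

    def __init__(self, rows):
        self.rows = list(rows)
        self.reads = 0

    def read(self):
        self.reads += 1
        out = self.rows[:]
        random.Random(self.reads).shuffle(out)  # a different order per read
        return out


def cell_sample(rows, lb, ub, seed):
    """Keep rows whose uniform draw x satisfies lb <= x < ub."""
    rng = random.Random(seed)
    return [row for row in rows if lb <= rng.random() < ub]


src = Source(range(100))
snapshot = src.read()  # like df.cache(): materialize once, then reuse
train = cell_sample(snapshot, 0.0, 0.8, seed=42)
test = cell_sample(snapshot, 0.8, 1.0, seed=42)
print(src.reads, len(train) + len(test), set(train).isdisjoint(test))  # 1 100 True
```

Without the snapshot, each split would call `src.read()` and see a different row order, which mirrors the uncached situation.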

Summary and Key Takeaways

The moral of the story: if unexpected behavior is happening in Spark, you just need to dig a bit deeper! Here is a summary of the key points of this article:

  • randomSplit() is equivalent to applying sample() on your data frame multiple times, with each sample re-fetching, partitioning, and sorting your data frame within partitions.
  • The distribution of data across partitions and the sorting order within partitions are important for both randomSplit() and sample(). If either changes when the data is re-fetched, splits may contain duplicates or miss rows, and sampling with the same seed may produce different results.
  • These inconsistencies may not happen on every run. To eliminate them completely, persist (aka cache) your data frame, repartition on one or more columns, or apply an aggregation such as groupBy().

Good luck in all of your Spark endeavors, and I hope this article is a good starting point to understanding the inner workings of Spark. I’d appreciate any and all comments you may have!
