Scaling Iterative Algorithms in Spark

Pruthvi Raj
Published in The Startup, Aug 5, 2020

Iterative algorithms are widely used in machine learning, connected components, PageRank, and similar problems. Their complexity grows with the number of iterations and with the size of the data at each iteration, and making them fault-tolerant at every iteration is tricky. In this article, I describe a few considerations in Spark for dealing with these challenges. We have used Spark to implement several iterative algorithms, such as building connected components and traversing large connected components. What follows comes from my experience at Walmart Labs, building connected components over 60 billion nodes of customer identities.

The number of iterations is never predetermined, so always have a termination condition 😀
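A minimal sketch of that rule, in plain Python rather than Spark: label propagation for connected components (a common technique, not necessarily the one used at Walmart Labs), where every node repeatedly adopts the smallest label among itself and its neighbours. The loop ends when an iteration changes no labels; `max_iterations` is a hypothetical safety cap so the driver never loops forever. The per-iteration change counts also show the "converging data" pattern described below.

```python
def connected_components(edges, max_iterations=100):
    """Label propagation: stop when no label changes (the termination condition)."""
    # Undirected adjacency list built from the edge pairs.
    neighbours = {}
    for a, b in edges:
        neighbours.setdefault(a, set()).add(b)
        neighbours.setdefault(b, set()).add(a)

    # Every node starts labelled with its own id.
    labels = {node: node for node in neighbours}
    changes = []  # how many labels changed in each iteration
    for iteration in range(1, max_iterations + 1):
        # Synchronous update: each node takes the minimum of its own and its neighbours' labels.
        new_labels = {
            node: min([labels[node]] + [labels[n] for n in nbrs])
            for node, nbrs in neighbours.items()
        }
        changed = sum(new_labels[node] != labels[node] for node in labels)
        changes.append(changed)
        labels = new_labels
        if changed == 0:  # nothing moved, so the components are final
            return labels, iteration, changes
    # Safety net: fail loudly instead of iterating forever.
    raise RuntimeError(f"no convergence within {max_iterations} iterations")


labels, iterations, changes = connected_components([(1, 2), (2, 3), (4, 5)])
print(labels)      # {1: 1, 2: 1, 3: 1, 4: 4, 5: 4}
print(iterations)  # 3
print(changes)     # [3, 1, 0]
```

In a distributed version, only the nodes whose label changed need to be carried into the next iteration, which is why this kind of algorithm tends to shrink as it runs.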

Types of iterative algorithms

Converging data: The amount of data decreases with every iteration; we start the first iteration with a huge dataset, and it shrinks as the iterations proceed. The main challenge is handling the huge datasets in the first few iterations; once the data has shrunk significantly, the remaining iterations up to the termination condition are easy to handle.

Diverging data: The amount of data increases at every iteration, and sometimes it grows so fast that it becomes impossible to proceed. To make these algorithms work, we need constraints such as a limit on the number of iterations, on the starting data size, or on the available compute.

Similar data: Every iteration works on more or less the same amount of data, and this kind of algorithm is the easiest to handle.

Incremental data: New data may be added at every iteration; in ML especially, new training data can arrive at periodic intervals.
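For the diverging case, the constraints can be enforced directly in the driver loop. Below is a sketch in plain Python, assuming a hop-by-hop neighbourhood expansion; `expand_with_budget`, `max_iterations` and `max_rows` are hypothetical names standing in for the iteration limit and the data-size limit. It stops on convergence, at the iteration limit, or before the result outgrows the row budget.

```python
def expand_with_budget(neighbours, seeds, max_iterations, max_rows):
    """Grow the set of reached nodes hop by hop from `seeds`, under explicit limits."""
    reached = set(seeds)
    frontier = set(seeds)
    for iteration in range(1, max_iterations + 1):
        # One hop: every neighbour of the current frontier not reached before.
        next_frontier = {n for node in frontier for n in neighbours.get(node, ())} - reached
        if not next_frontier:
            return reached, iteration, "converged"
        # Data-size constraint: stop before the result exceeds the budget.
        if len(reached) + len(next_frontier) > max_rows:
            return reached, iteration, "row budget exceeded"
        reached |= next_frontier
        frontier = next_frontier
    # Iteration-count constraint.
    return reached, max_iterations, "iteration limit reached"


tree = {i: [2 * i, 2 * i + 1] for i in range(1, 64)}  # the frontier doubles each hop
for limits in [(10, 1000), (10, 20), (2, 1000)]:
    reached, stopped_at, reason = expand_with_budget(tree, {1}, *limits)
    print(limits, len(reached), stopped_at, reason)
# (10, 1000) 127 7 converged
# (10, 20) 15 4 row budget exceeded
# (2, 1000) 7 2 iteration limit reached
```

In Spark, checking a row budget like this would typically mean a `count()` on each iteration's result, which is an extra job per iteration, so it is worth the cost mainly when the data can explode.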

Challenges

  1. RDD lineage: A common way to keep a system fault-tolerant is to keep replicas of the data in different places, so that if one node goes down, a replica can take over until the node recovers. By default, Spark does not rely on replicas; instead, the driver maintains a lineage graph of the transformations applied to the data. If any piece of the data is lost, Spark rebuilds it from the lineage graph, which is what makes Spark fault-tolerant. However, the lineage graph grows with every iteration, so rebuilding lost data from it becomes harder as the number of iterations increases.
  2. Memory and disk I/O: RDDs in Spark are immutable, so every iteration creates a new copy of the transformed data (a new RDD), which increases memory and disk usage. As iterations increase, memory and disk usage on the executors grows, which can slow the job down through memory pressure and waits for garbage collection. In some cases, the heap is not large enough and the task is killed.
  3. Task size: In some cases, a few tasks may not fit on a single executor, or a single task takes much longer than the rest, which creates a bottleneck.
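To make the lineage problem concrete, here is a toy model (plain Python, not Spark's API) of how many iterations must be recomputed to rebuild the output of iteration i if it is lost. Without truncation the replay reaches back to iteration 1; truncating the lineage every N iterations, as the tips below suggest, bounds it by N. In Spark the truncation would come from something like `SparkContext.setCheckpointDir` plus `RDD.checkpoint()` or `DataFrame.checkpoint()`, or from writing to a table and reading it back; `replay_length` is a hypothetical helper, and the model ignores the cost of those writes.

```python
def replay_length(iteration, checkpoint_every=None):
    """Toy model: iterations to recompute if the output of `iteration` is lost.

    Without checkpoints, the lineage reaches back to iteration 1.
    With a checkpoint saved at the end of every N-th iteration, the lineage
    starts at the most recent checkpoint taken before this iteration.
    """
    if iteration < 1:
        raise ValueError("iterations are numbered from 1")
    if checkpoint_every is None:
        return iteration
    last_checkpoint = ((iteration - 1) // checkpoint_every) * checkpoint_every
    return iteration - last_checkpoint


for i in (5, 50, 500):
    print(i, replay_length(i), replay_length(i, checkpoint_every=10))
# 5 5 5
# 50 50 10
# 500 500 10
```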

Tips to overcome the above challenges

  1. Keeping a large lineage graph in memory is costly, and if a node fails, rebuilding the lost datasets is time-consuming. In such cases, we can cache or checkpoint the dataset every N iterations, which saves the computed RDD at every N-th iteration (caching stores it in executor memory or on executor disk, while checkpointing writes it to reliable storage such as HDFS; their speeds differ, so choose based on your needs). If there is a failure, the RDD is recomputed from the last checkpoint or cached copy. Note that caching keeps the lineage, so a lost cached partition is still recomputed from the full lineage, whereas checkpointing truncates it. Instead of these two methods, you can also create a temporary table and save the computed dataset partitioned by iteration. If the Spark job fails, we can restart from the last N-th iteration, and saving to a temporary table has the added advantage of discarding the RDD lineage graph up to that iteration and starting a fresh lineage graph from that point. Because the RDD lineage graph grows large in iterative algorithms, we need hybrid solutions that combine caching, checkpointing (see reference [2]) and temporary tables for different use cases.
  2. As above, writing to a temporary table and reading it back discards the lineage graph and frees the memory and disk held by previous RDDs. These writes and reads add overhead, but they give a huge advantage when handling large datasets. With converging datasets in particular, we might need to do this only for the first few iterations and can switch to caching once the data has become small. Saving to a temporary table looks like a trivial checkpoint, but it does more than that: by doing it at periodic iterations, we also discard the lineage history, which reduces the risk of the job failing and the time needed to rebuild lost data.
  3. Handling diverging data is tricky because the size of each task keeps increasing with iterations, so each executor takes much longer. We therefore need a factor for the number of tasks in iteration i + 1 relative to iteration i such that the task size stays the same. For example, say iteration i has 100 tasks, each processing around 100 MB of data. If in iteration i + 1 each task grows to 150 MB (15,000 MB in total), we can reshuffle those 100 tasks into 150 tasks and keep about 100 MB per task. So with diverging datasets, we need to increase the number of tasks by repartitioning and by changing the number of shuffle partitions based on the iteration.
  4. When a Spark task is huge, try increasing the executor memory so the task fits. When joining skewed datasets, where 10% of the tasks take 90% of the runtime and the other 90% complete in 10% of the time, it helps to handle these tasks separately by running two different queries. We need to identify the reason for the large tasks and see whether we can separate them into two groups, small tasks and large tasks. The first query processes the 90% of tasks that have no such hurdle, which takes about 10% of the time, as before. The second query processes the large tasks (the remaining 10%) using a broadcast join, since there are few of them, which also avoids shuffling that data.
    Example: Say we have Table A and Table B. Table A holds population data with columns user_id, name, city, state. Table B holds WhatsApp group data with columns user_id, group_id. Suppose we want to find the top 5 cities by number of WhatsApp groups used. There are corner cases here: cities with a large population can produce large tasks, and so can users who belong to many groups. To handle them, the join between these tables can be done in two queries. We filter out the large users with many groups (say, a threshold of 1,000 groups per user) and treat them as large tasks, joining them separately with a broadcast join, since there are few large users compared with the total data. For the remaining users, we perform a shuffle join, then combine both results and aggregate by city to find the top 5 cities.
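The task-count arithmetic in tip 3 fits in a small helper. This is a sketch in plain Python (`partitions_for` is a hypothetical name); how you estimate the next iteration's total size, for example from the previous iteration's output size, is left as an assumption.

```python
def partitions_for(total_bytes, target_bytes_per_task, min_partitions=1):
    """Smallest task count that keeps each task at or below the target size."""
    # Ceiling division in integers avoids float rounding on large byte counts.
    needed = (total_bytes + target_bytes_per_task - 1) // target_bytes_per_task
    return max(min_partitions, needed)


MB = 1024 * 1024

# The example from tip 3: iteration i ran 100 tasks of ~100 MB each; in
# iteration i + 1 the same 100 tasks would hold ~150 MB each.
next_total = 100 * 150 * MB
n = partitions_for(next_total, target_bytes_per_task=100 * MB)
print(n)  # 150

# In PySpark, n would then be applied before the next iteration's shuffle, e.g.
#   spark.conf.set("spark.sql.shuffle.partitions", str(n))
#   df = df.repartition(n)
```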

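The two-query split from tip 4 can be sketched in plain Python on in-memory rows shaped like Table A and Table B. This illustrates the idea rather than Spark code: the "broadcast" side is a small dict of the heavy users' cities, the "shuffle" side is simulated by hash-partitioning both inputs on `user_id`, and `top_cities_by_groups`, `heavy_threshold` and `num_partitions` are hypothetical names. It assumes "number of WhatsApp groups" means distinct groups per city.

```python
from collections import Counter, defaultdict


def top_cities_by_groups(population, memberships, heavy_threshold=1000, k=5, num_partitions=4):
    """population: (user_id, name, city, state) rows, like Table A.
    memberships: (user_id, group_id) rows, like Table B."""
    # Split users: "heavy" users belong to at least heavy_threshold groups.
    groups_per_user = Counter(user_id for user_id, _ in memberships)
    heavy = {u for u, n in groups_per_user.items() if n >= heavy_threshold}

    # Query 1 (broadcast-style): few heavy users, so their city lookup is small
    # enough to ship to every task; their many membership rows stay in place.
    heavy_city = {u: city for u, _, city, _ in population if u in heavy}
    joined = [(heavy_city[u], g) for u, g in memberships if u in heavy_city]

    # Query 2 (shuffle-style): hash-partition both sides on user_id and join
    # within each partition, as a shuffle join would.
    parts = [{} for _ in range(num_partitions)]
    for u, _, city, _ in population:
        if u not in heavy:
            parts[hash(u) % num_partitions][u] = city
    for u, g in memberships:
        if u not in heavy:
            city = parts[hash(u) % num_partitions].get(u)
            if city is not None:
                joined.append((city, g))

    # Combine both results, then aggregate: distinct groups per city.
    groups_by_city = defaultdict(set)
    for city, g in joined:
        groups_by_city[city].add(g)
    ranked = sorted(groups_by_city.items(), key=lambda item: (-len(item[1]), item[0]))
    return [(city, len(groups)) for city, groups in ranked[:k]]


population = [(1, "a", "Austin", "TX"), (2, "b", "Boston", "MA"),
              (3, "c", "Austin", "TX"), (4, "d", "Denver", "CO")]
memberships = [(1, 10), (1, 11), (1, 12), (2, 10), (3, 10), (3, 13), (4, 14)]
print(top_cities_by_groups(population, memberships, heavy_threshold=3, k=2))
# [('Austin', 4), ('Boston', 1)]
```

Because both halves produce the same (city, group_id) rows, aggregating after the union gives the same answer as a single join. In PySpark, the same shape would likely use `pyspark.sql.functions.broadcast` on the heavy users' rows of Table A, a regular `join` for the rest, `unionByName` to combine them, and `groupBy("city").agg(countDistinct("group_id"))` before taking the top 5.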
Please add any questions in the comments. Here are some relevant links for more details:

References:

[3] An interesting research paper
[4] Spark documentation
