Computing global rank of a row in a DataFrame with Spark SQL

Sushil Kumar
Published in The Startup
Sep 14, 2019
Source: https://tryengineering.org/teacher/fun-sorting/

In my current role I’m working on implementing predictive modelling for a customer personalization problem. We heavily utilize Apache Spark, both for our ML jobs (Spark MLlib) and for other non-ML batch jobs (Spark SQL). More often than not, a situation arises where I have to globally rank each row in a DataFrame based on the order of a certain column. In the ML world, where we deal with sampled data sets, using the windowed rank() function without any partitionBy clause works fine: the data sets are small (due to sampling) and fit in a single executor’s memory. A simple solution like the one below works alright.

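import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{functions => F}
val w = Window.orderBy("sort_column") // no partitionBy: every row ends up in one partition
df.withColumn("rank", F.rank().over(w))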

Once we are past the ML stage (a.k.a. the modelling stage), we have to apply the same operation to the entire data set, and we soon find ourselves in the soup: skipping the partitionBy clause in a Window causes the entire data set to be shuffled into a single partition, processed by a single task on one executor, and the job fails with OOM errors. Increasing the executor’s memory doesn’t help either, because it’s simply not practical to have an executor with 1 TB of memory.

One afternoon my colleague Deepak and I were brainstorming when we discovered a really neat property of Spark DataFrame’s orderBy function.

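If you call orderBy on a DataFrame and do not call any further operation that causes the DataFrame to shuffle, then the data stays sorted both within and across partitions. This is because orderBy range-partitions the rows by the sort column (each partition receives a contiguous range of values) and then sorts each partition locally.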

Let me explain this with a visualization.

DataFrame is sorted by partitions after orderBy()

As you can see, after the orderBy operation not only is each partition sorted internally, but the partitions are also ordered relative to one another: all_values(partition-0) < all_values(partition-1) < all_values(partition-2).
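Why does orderBy leave the data in this shape? Spark range-partitions the rows by the sort key and then sorts each partition locally. The plain-Scala sketch below mimics that on an in-memory list, assuming a hypothetical, simplified partitioner with hand-picked bounds (Spark’s RangePartitioner derives its bounds by sampling the data, and `RangePartitionDemo`/`partitionFor` are illustrative names, not Spark APIs). Reading the partitions in ID order yields the fully sorted data.

```scala
object RangePartitionDemo {
  // Hypothetical, simplified range partitioner: key k goes to the first partition i
  // with k <= bounds(i); keys above every bound go to the last partition.
  // Equal keys therefore always land in the same partition.
  def partitionFor(k: Int, bounds: Vector[Int]): Int = {
    val i = bounds.indexWhere(b => k <= b)
    if (i == -1) bounds.length else i
  }

  // Mimic orderBy: route each key to its range partition, then sort every partition locally.
  def sortByRange(keys: Seq[Int], bounds: Vector[Int]): Vector[Vector[Int]] = {
    val grouped = keys.groupBy(k => partitionFor(k, bounds))
    Vector.tabulate(bounds.length + 1) { p =>
      grouped.getOrElse(p, Seq.empty[Int]).sorted.toVector
    }
  }
}

// Bounds (10, 50) give three partitions: (.., 10], (10, 50], (50, ..)
val rangeParts = RangePartitionDemo.sortByRange(Seq(42, 7, 19, 3, 88, 19, 55), Vector(10, 50))
println(rangeParts)         // Vector(Vector(3, 7), Vector(19, 19, 42), Vector(55, 88))
println(rangeParts.flatten) // Vector(3, 7, 19, 19, 42, 55, 88)
```

Note that both 19s land in partition 1: because equal keys never straddle a partition boundary, ties can only occur inside a single partition.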

We had a brilliant idea of leveraging this property of the DataFrame to assign a global rank to each row.

We first assigned a partitionId to each row using Spark’s built-in spark_partition_id() function.

import org.apache.spark.sql.functions.spark_partition_id
// withColumn is a narrow transformation, so each row keeps the partition orderBy placed it in
val partDf = df.orderBy("sort_column").withColumn("partitionId", spark_partition_id())

After this, we assigned a rank within each partition using the rank() function over a window.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank

val w = Window.partitionBy("partitionId").orderBy("sort_column")
// Since we have a partitionBy clause, data will be well split.
val rankDf = partDf.withColumn("local_rank", rank().over(w))

This rank is local to each partition, but we can leverage the fact that the data is also sorted across partitions, i.e. all_values(partition-0) < all_values(partition-1) < all_values(partition-2) < …, so a row’s global rank is its local rank plus the number of rows in all the partitions before it.

Local Rank for each partition
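For intuition, this plain-Scala sketch reproduces what rank() computes inside each partition, using made-up partition contents and a hypothetical `LocalRankDemo` object: every partition starts counting again from 1, and tied values share a rank.

```scala
object LocalRankDemo {
  // SQL-style rank() over one partition that is already sorted ascending:
  // a row's rank is 1 + the number of rows with a strictly smaller value,
  // i.e. the position of the first occurrence of its value. Ties share a rank.
  def localRanks(sortedPartition: Vector[Int]): Vector[Int] =
    sortedPartition.map(v => sortedPartition.indexOf(v) + 1)
}

println(LocalRankDemo.localRanks(Vector(3, 7, 19, 19))) // Vector(1, 2, 3, 3)
println(LocalRankDemo.localRanks(Vector(42)))           // Vector(1)
println(LocalRankDemo.localRanks(Vector(55, 88)))       // Vector(1, 2)
```

The `indexOf` lookup is quadratic, which is fine for a toy example; Spark’s window operator computes the same ranks in a single pass over each sorted partition.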

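We’ll create a stats DataFrame derived from rankDf, holding the number of rows in each partition and a running sum of those counts. Counting rows is safer than taking max(local_rank): rank() gives tied rows the same rank, so when the last rows of a partition are tied, its maximum local rank is smaller than its row count.

import org.apache.spark.sql.{functions => F}

// One row per partition: its partitionId and the number of rows it holds
val tempDf = rankDf.groupBy("partitionId").agg(F.count(F.lit(1)).alias("row_count"))

// Running total of row counts in partitionId order
val cumWindow = Window.orderBy("partitionId").rowsBetween(Window.unboundedPreceding, Window.currentRow)
val statsDf = tempDf.withColumn("cum_rank", F.sum("row_count").over(cumWindow))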
Stats DF derived from base DF
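The per-partition statistics are easy to reproduce in plain Scala. This sketch (hypothetical `PartitionStatsDemo` object, made-up partitions in the shape orderBy produces) computes rank() inside a partition that ends in two tied rows, the row count of each partition, and the inclusive running total that the cum_rank column is meant to hold, i.e. the number of rows up to and including each partition.

```scala
object PartitionStatsDemo {
  // rank() over an already-sorted partition: ties share the rank of their first occurrence.
  def localRanks(sortedPartition: Vector[Int]): Vector[Int] =
    sortedPartition.map(v => sortedPartition.indexOf(v) + 1)

  // Number of rows in each partition, in partition-id order.
  def rowCounts(partitions: Vector[Vector[Int]]): Vector[Long] =
    partitions.map(_.size.toLong)

  // Inclusive running total of the row counts (partition 0, then 0..1, then 0..2, ...).
  def cumulativeRowCounts(partitions: Vector[Vector[Int]]): Vector[Long] =
    rowCounts(partitions).scanLeft(0L)(_ + _).tail
}

val statsParts = Vector(Vector(3, 7, 19, 19), Vector(42), Vector(55, 88))
println(PartitionStatsDemo.localRanks(statsParts(0)).max)      // 3
println(PartitionStatsDemo.rowCounts(statsParts))              // Vector(4, 1, 2)
println(PartitionStatsDemo.cumulativeRowCounts(statsParts))    // Vector(4, 5, 7)
```

The first partition’s maximum local rank is 3 while it holds 4 rows; using 3 as its contribution to the running total would make every rank in the later partitions one too small.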

We have skipped the partitionBy clause in this window spec because tempDf has only N rows (N being the number of partitions of the base DataFrame) and only 2 columns, so it will always fit in a single executor’s memory.

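Now we’ll self-join statsDf to shift the cumulative sum down by one row, so that each partition gets the cumulative sum of all the partitions before it. This shifted value is the sum_factor (0 for the first partition).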

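// Pair each partition (l) with the partition just before it (r).
// Partition 0 has no predecessor, so coalesce turns its missing cum_rank into 0.
val joinDf = statsDf.alias("l")
  .join(statsDf.alias("r"), F.col("l.partitionId") === F.col("r.partitionId") + 1, "left")
  .select(F.col("l.partitionId"), F.coalesce(F.col("r.cum_rank"), F.lit(0)).alias("sum_factor"))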
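The shift performed by the self-join is an exclusive prefix sum: each partition receives the total of all partitions before it, and partition 0 receives 0 through the coalesce. A plain-Scala sketch with made-up numbers (hypothetical `SumFactorDemo` object) shows the shift, plus an equivalent formulation that subtracts each partition’s own row count from its running total.

```scala
object SumFactorDemo {
  // What the left self-join computes: shift the inclusive running totals down by one
  // partition, with 0 for the first partition (the coalesce(..., 0) case).
  def sumFactorsByShift(cumCounts: Vector[Long]): Vector[Long] =
    (0L +: cumCounts).init

  // Equivalent without a join: each partition's running total minus its own row count.
  def sumFactorsBySubtraction(rowCounts: Vector[Long]): Vector[Long] = {
    val cumCounts = rowCounts.scanLeft(0L)(_ + _).tail
    cumCounts.zip(rowCounts).map { case (cum, own) => cum - own }
  }
}

println(SumFactorDemo.sumFactorsByShift(Vector(4L, 5L, 7L)))       // Vector(0, 4, 5)
println(SumFactorDemo.sumFactorsBySubtraction(Vector(4L, 1L, 2L))) // Vector(0, 4, 5)
```

In Spark terms, the subtraction form would be a column computed as the window sum minus the per-partition count, which removes the self-join and does not rely on partition IDs being contiguous (the `+ 1` join condition does). Treat it as a possible variant rather than part of the approach described here.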

Once joinDf is ready, we join it back to rankDf using a broadcast join (joinDf is always small enough to broadcast, since it has N rows and 2 columns). The final rank of a row is local_rank + sum_factor, where sum_factor is the total number of rows in all the partitions before the row’s own partition.

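// joinDf has one row per partition, so broadcasting it avoids shuffling rankDf again
val finalDf = rankDf
  .join(F.broadcast(joinDf), Seq("partitionId"), "inner")
  .withColumn("rank", F.col("local_rank") + F.col("sum_factor"))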
Globally ranked rows in a DataFrame
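To sanity-check the whole scheme, this plain-Scala sketch runs it on in-memory partitions shaped like orderBy output (made-up values; `GlobalRankDemo` is a hypothetical name) and compares the result with rank() computed directly over the fully sorted data. It uses each partition’s row count for the offset, so a partition ending in tied values does not shift the ranks of later partitions.

```scala
object GlobalRankDemo {
  // rank() over an already-sorted sequence: tied values share the rank of their first occurrence.
  def ranks(sorted: Vector[Int]): Vector[Int] = sorted.map(v => sorted.indexOf(v) + 1)

  // Input: partitions as produced by orderBy (each sorted, and ordered relative to each other).
  // Output: (value, global rank) pairs, where global rank = local rank + rows in earlier partitions.
  def globalRanks(partitions: Vector[Vector[Int]]): Vector[(Int, Int)] = {
    val counts = partitions.map(_.size)
    val sumFactors = counts.scanLeft(0)(_ + _).init // exclusive prefix sum of row counts
    partitions.zip(sumFactors).flatMap { case (part, offset) =>
      part.zip(ranks(part)).map { case (value, local) => (value, local + offset) }
    }
  }
}

val sortedParts = Vector(Vector(3, 7, 19, 19), Vector(42), Vector(55, 88))
val viaPartitions = GlobalRankDemo.globalRanks(sortedParts)
val flat = sortedParts.flatten
val direct = flat.zip(GlobalRankDemo.ranks(flat))
println(viaPartitions)           // Vector((3,1), (7,2), (19,3), (19,3), (42,5), (55,6), (88,7))
println(viaPartitions == direct) // true
```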

One gotcha when running this code: you’ll get the following warning every time a Window without a partitionBy clause is executed.

WARN  WindowExec:66 - No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.

You can safely ignore it here, because we only apply such windows to DataFrames with a small, predictable number of rows and columns (one row per partition), even if the source DataFrame is huge.

And there you have it: globally ranked rows in a DataFrame with Spark SQL.

In case you find any issues in my code or have any questions, feel free to drop a comment below.

Till then, happy coding! :)

A polyglot developer with a knack for Distributed systems, Cloud and automation.