Spark partitioning: full control

Vladimir Prus
Oct 25, 2021

In this post, we’ll learn how to explicitly control partitioning in Spark, deciding exactly where each row should go. This is an important tool for producing well-sized files in S3 storage and for interacting efficiently with external systems.


Motivation

Let’s revisit the motivating example from the earlier post:

You are implementing event ingestion. A streaming pipeline reads from Kafka and writes files to a landing location. Since streaming must run fast, the files it creates are small and contain different types of events mixed together. Once a day, you want to compact the events into a few large files, separated by event type.

The code for daily compaction was easy:

spark
  .read.parquet("s3://your-landing-bucket/day=2021-08-28")
  .repartition(16, $"device_id")
  .write
  .partitionBy("type")
  .parquet("s3://your-mart-bucket/day=2021-08-28")

The challenge is that different events have different frequencies. There might be a frequent event called appOpen, a less frequent productOpen event, and a purchase event that happens a few times a day. If we wish the resulting files to be large enough, say around 1GB, we might need 5 files for the app open event, but only a single file for the purchase event.

We can process each event type separately, but if there are 1000 different types, and we’re joining them with other data sources, the overhead becomes prohibitive. It would be cool to process all events together but, instead of repartitioning by a single key into a single specific partition count, use different partition counts for different types.

Computing the desired partition

Let’s assume that for each event type, we already know the desired partition count. Let’s also assume there’s a column we can use for partitioning inside a given type. Then, we can easily compute the total number of partitions we need and allocate blocks of partitions to different event types.

This plan is easy to translate into code: we add a new column desired_partition, computed from the type and device_id columns and the static partition counts:

val counts = typedLit(Map(
  "appOpen" -> 5,
  "productOpen" -> 2,
  "purchase" -> 1
))
val offsets = typedLit(Map(
  "appOpen" -> 0,
  "productOpen" -> 5,
  "purchase" -> 7
))
df.withColumn("desired_partition",
  offsets($"type") + pmod(hash($"device_id"), counts($"type")))

Fighting default partitioning

Having computed the desired partition column we can try to repartition:

df.repartition(8, $"desired_partition")

This does not do exactly what we want: some partitions will contain more than one value of the desired_partition column, and some partitions will have no data at all. Recall that repartition first computes a hash of the incoming keys, and then uses the hash, modulo the number of partitions, to determine the target partition. Hash values can collide, and in fact collide quite often. To illustrate, run this code:

(0 to 7)
  .toDF("desired_partition")
  .withColumn("partition", pmod(hash($"desired_partition"), lit(8)))
  .orderBy("desired_partition")
  .show()

It turns out that rows with desired partitions 1, 3, and 7 all hash into partition 3, and rows with desired partitions 2 and 4 end up in partition 6. Not exactly the explicit repartitioning we want.

Another approach we can try is repartitioning by a range, where a target partition is determined by the location of the key among a set of breakpoints.

df.repartitionByRange(8, $"desired_partition")

This will do exactly what we need: there will be 8 partitions, and each will contain rows with exactly one value of the desired_partition column. But it will take considerably longer than our prior code. The reason is that repartitionByRange does not accept explicit ranges; it samples the data to determine them, which practically doubles the effort.

Finally, we can convert to RDD and use an RDD Partitioner, which gives us exact control. But as we know, the conversion to RDD is itself expensive.
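
For reference, here is a minimal sketch of that RDD route, assuming the desired_partition column computed earlier and a spark session in scope; note the conversions to and from RDD[Row], which are exactly the overhead we would like to avoid:

import org.apache.spark.Partitioner

// Key each row by its precomputed target partition, shuffle with a
// custom Partitioner that uses the key verbatim, then drop the key.
val partitionedRdd = df.rdd
  .keyBy(row => row.getAs[Int]("desired_partition"))
  .partitionBy(new Partitioner {
    override def numPartitions: Int = 8
    override def getPartition(key: Any): Int = key.asInstanceOf[Int]
  })
  .values

// Converting back to a DataFrame requires re-applying the schema.
val explicitlyPartitioned = spark.createDataFrame(partitionedRdd, df.schema)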

It’s about time we recall that extending Spark is also an option.

Explicit partitioning

We would like to create a new DataFrame operation that can partition exactly by a value of an integer column. As explained in the previous post, to add a new DataFrame operation we usually need to define a new logical operation, a new physical operation, and a strategy to convert between them. But since Spark already has a physical operation for repartitioning, called ShuffleExchangeExec, our task is even simpler — we just define a logical operator, like so:

case class ExplicitRepartition(
    partitionExpression: Expression,
    child: LogicalPlan,
    numPartitions: Int)
  extends RepartitionOperation {

  val partitioning: Partitioning = {
    new ExplicitPartitioning(Seq(partitionExpression), numPartitions)
  }
}

Our new operation extends RepartitionOperation, so Spark knows to convert it to ShuffleExchangeExec. We only need to request our custom partitioning, defined below:

class ExplicitPartitioning(
    expression2: Seq[Expression], numPartitions2: Int)
  extends HashPartitioning(expression2, numPartitions2) {

  override def satisfies0(required: Distribution): Boolean = {
    super.satisfies0(required)
  }

  override def partitionIdExpression: Expression = expression2.head
}

You might wonder why we started out wanting to explicitly specify a partition and are now extending HashPartitioning, which sounds like implementing exactly the hashing we want to avoid. First, let’s look at the base class:

case class HashPartitioning(
    expressions: Seq[Expression],
    numPartitions: Int)
  extends Expression with Partitioning with Unevaluable {

  def partitionIdExpression: Expression = Pmod(
    new Murmur3Hash(expressions),
    Literal(numPartitions)
  )
}

As you see, the base class is responsible for the hash-and-take-remainder logic, and we’re overriding it to use only the expression itself, so our code is correct. But why extend a class just to implement a one-line method? Our logical operation will be converted to a built-in physical operation called ShuffleExchangeExec, which requires that one of 4 supported partitioning schemes is used (see the code). Hijacking HashPartitioning might not look nice, but it’s much easier than trying to replicate one of the key physical operators in Spark.

With these definitions, our explicit partitioning code becomes

df
  .transform { df =>
    val sparkSession = df.sparkSession
    val e = RowEncoder(df.schema)
    new DataFrame(sparkSession,
      ExplicitRepartition($"desired_partition".expr,
        df.queryExecution.analyzed, 8), e)
  }

Of course, as earlier, we can introduce an implicit class to simplify the usage.

implicit class ExplicitRepartitionWrapper(df: DataFrame) {
  def explicitRepartition(numPartitions: Int, partitionExpression: Column): DataFrame = {
    val sparkSession = df.sparkSession
    val e = RowEncoder(df.schema)
    new DataFrame(sparkSession,
      ExplicitRepartition(partitionExpression.expr, df.queryExecution.analyzed, numPartitions), e)
  }
}

Then, the repartitioning code becomes just

df.explicitRepartition(8, $"desired_partition")

This is the key part of this post: we’ve developed an operator that partitions a dataframe exactly as we wish. A row ends up in the partition given by a column, without hash collisions or additional passes. The complete implementation can be found in Joom’s open-source repository.

Computing partition counts

Above, we have glossed over computing the optimal partition count for each event type. Our goal is to hit a desired target file size, so we actually want to estimate the Parquet file size for each type and divide it by the desired file size to obtain the partition count. The Parquet file size, in turn, depends on how well compression works on our data and can only be approximated. There are several strategies.

The easiest way is to count the events and multiply the count for each type by the approximate size of one event, collected from historical data. Alternatively, we can use the size of the data written yesterday as a prediction of the size of today’s data. These approaches are fairly reasonable but can fail if there is a sudden spike in event count or event size.
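
As an illustration of the first strategy, here is a hedged sketch that assumes hypothetical average per-event sizes (in bytes) collected from historical data and a roughly 1GB target file size; the map contents and column names are illustrative:

import org.apache.spark.sql.functions._
import spark.implicits._ // spark is the active SparkSession

// Hypothetical average on-disk bytes per event, from historical data.
val avgEventSize = typedLit(Map(
  "appOpen" -> 200L,
  "productOpen" -> 2000L,
  "purchase" -> 5000L
))
val targetFileSize = 1024L * 1024 * 1024 // aim for files of roughly 1GB

val partitionCounts = df
  .groupBy("type")
  .count()
  .withColumn("estimated_bytes", $"count" * avgEventSize($"type"))
  .withColumn("partitions",
    greatest(lit(1L), ceil($"estimated_bytes" / targetFileSize)))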

Even better is to look at the amount of data we’re actually reading. Since Spark already collects this information to optimize query plans, we can extract it as well, for example:

val df = spark.read.parquet("s3://....")
val bytes = df.queryExecution.logical.stats.sizeInBytes.toLong

It often works great, but it computes the total bytes, while we want the bytes per event type.

Finally, we can use a brute force approach — list objects in S3, and directly compute the size. For example, we can do this:

import scala.collection.JavaConverters._

val objects = S3Objects.withPrefix(s3, "landing-bucket", "events").asScala
val r = s".*/type=(.*)/date=${dateString}/.*".r

val result = objects.flatMap(obj => obj.getKey() match {
  case r(t) => Option((obj.getKey(), t, obj.getSize())) // one capture group: the event type
  case _ => None
}).toSeq

This is not very elegant but is probably the most efficient approach.
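
To turn that listing into per-type partition counts, one could aggregate the sizes, again with a roughly 1GB target; a small sketch under the same assumptions:

// result is the Seq of (key, eventType, sizeInBytes) built above.
val targetFileSize = 1024L * 1024 * 1024

val partitionCounts: Map[String, Int] = result
  .groupBy { case (_, eventType, _) => eventType }
  .map { case (eventType, objects) =>
    val totalBytes = objects.map(_._3).sum
    eventType -> math.max(1, math.ceil(totalBytes.toDouble / targetFileSize).toInt)
  }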

Conclusion

We’ve looked at explicitly controlling the partitioning of a Spark dataframe. The key motivation is optimizing table storage, where we want a uniform data size distribution across all files. It can also be useful when interacting with external systems that require a particular partitioning. The standard partitioning mechanism in Spark falls short because it uses hashing internally, and hash collisions often result in a non-uniform distribution.

In this post, we’ve developed a custom Spark operator that performs such explicit partitioning — a row will end up exactly in the desired partition. The complete code is also available.

Thanks for reading, and subscribe for more data engineering stories.
