Unlocking Spark’s Potential with Scala: Splitting a DataFrame by Left Join

Alexander Lopatin
6 min read · Dec 24, 2023


In this post, I’m going to write a function that uses Scala’s functional approach.

In my practice, there are a lot of ETLs that require splitting a data frame into two data frames by a condition. A common use case is splitting by left join: we left-join two data frames and want to transform the two parts with different algorithms, the part that matched the right table and the part that didn’t.

I’d like to reuse this algorithm multiple times. So let’s write a function that takes two data frames and a join condition as input parameters.

Here’s a function splitByLeftJoin:

def splitByLeftJoin(leftDF: DataFrame, rightDF: DataFrame, drop: Boolean)(condition: ConditionType): (DataFrame, DataFrame) = {
  val joinDF =
    leftJoin(
      leftDF,
      rightDF.withColumn("condition", lit(true)),
      condition
    )
      .withColumn("condition", coalesce(col("condition"), lit(false)))

  val dropCondition =
    if (drop) {
      val columns = leftDF.columns.map(leftDF(_))
      (df: DataFrame) => df.select(columns: _*)
    } else {
      (df: DataFrame) => df.drop("condition")
    }

  val (trueConditionDF, falseConditionDF) = splitByCondition(joinDF)(col("condition"))

  val (trueConditionDFResult, falseConditionDFResult) =
    (
      trueConditionDF.transform(dropCondition),
      falseConditionDF.transform(dropCondition)
    )

  (trueConditionDFResult, falseConditionDFResult)
}

Let’s break it down.

Input parameters

Our function takes two data frames, leftDF and rightDF; a drop parameter that says whether we need to drop the right table’s columns from the resulting data frames; and a condition of type ConditionType that describes the left join condition. I’ll explain what’s going on with condition later.

Function leftJoin

The leftJoin function is the heart of our module.

private def leftJoin(leftDF: DataFrame, rightDF: DataFrame, condition: ConditionType): DataFrame = {
  condition match {
    case ColumnType(cond) =>
      leftDF.join(rightDF, cond, "left")
    case SeqType(cond) =>
      leftDF.join(rightDF, cond, "left")
    case _ =>
      throw new IllegalArgumentException
  }
}

It takes three parameters and has the private modifier.

As you know, in Spark, to join two data frames you pass a condition of one of two types: a Column or a sequence of strings (Seq[String]). I’d like to keep this flexibility in my splitByLeftJoin function. So I have to find a way to support both variations of join: with a Column parameter and with a Seq parameter.
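To make the two signatures concrete, here is a minimal sketch (assuming Spark is on the classpath and runs locally; the data frames and column names are made up for illustration):

```scala
// A minimal sketch of the two join signatures splitByLeftJoin should support.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.master("local[1]").appName("join-forms").getOrCreate()
import spark.implicits._

val left  = Seq((1, "a"), (2, "b")).toDF("id", "l")
val right = Seq((1, "x")).toDF("id", "r")

// Column condition: both `id` columns remain in the result.
val byColumn = left.join(right, left("id") === right("id"), "left")

// Seq[String] condition: the join column is deduplicated into a single `id`.
val bySeq = left.join(right, Seq("id"), "left")
```

Note the difference in output schema: the Seq[String] form keeps only one copy of the join columns, while the Column form keeps both.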

But how do we do that with minimal boilerplate? My first idea was to use overloaded functions and write two versions of the part where we do the join.

But at the same time, I’d like to use Scala’s ability to declare multiple parameter lists. That way, I can write my condition as a block, like this:

val (resultDFTrue, resultDFFalse) =
  splitByLeftJoin(leftDF, rightDF, true) {
    col("left_column") === col("right_column")
  }

Because a join condition can be more complex than a simple equality of two columns, it can be convenient to call the function this way.

But there is a problem with multiple parameter lists in Scala 2: it doesn’t allow overloading such functions when the overloaded parameter is not in the first parameter list.

So I tried a bunch of variations and found one that really works.

I declared a new trait and a couple of case classes that extend it:

trait ConditionType

case class ColumnType(condition: Column) extends ConditionType
case class SeqType(condition: Seq[String]) extends ConditionType

And now I use pattern matching in my leftJoin function to determine which variant of join to use: the Column one or the Seq[String] one.
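Stripped of Spark, the dispatch can be sketched in plain Scala. Here `describe` is a hypothetical stand-in for the actual join call, and a String stands in for Spark’s Column type:

```scala
// Plain-Scala sketch of the pattern-matching dispatch in leftJoin.
trait ConditionType
case class ColumnType(condition: String) extends ConditionType
case class SeqType(condition: Seq[String]) extends ConditionType

def describe(condition: ConditionType): String = condition match {
  case ColumnType(cond) => s"join on expression: $cond"
  case SeqType(cond)    => s"join on columns: ${cond.mkString(", ")}"
}

println(describe(ColumnType("left_id = right_id")))  // join on expression: left_id = right_id
println(describe(SeqType(Seq("id", "date"))))        // join on columns: id, date
```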

There is still one issue. Now the user of my function has to create an object of one of the two classes, ColumnType or SeqType. I don’t want the user to do that. I want the user to simply pass an expression of Column type, or a sequence of column names, into my function.

We can use implicit methods for this:

implicit def columnToColumnType(column: Column): ColumnType = ColumnType(column)
implicit def SeqToSeqType(seq: Seq[String]): SeqType = SeqType(seq)

These two methods define implicit conversions.

Whenever an object of type ConditionType is expected, Scala will apply an implicit conversion from the given type (for further details, see the Scala documentation on implicit conversions). Because both ColumnType and SeqType extend ConditionType, both implicit methods work when we call splitByLeftJoin with a parameter of Column or Seq[String] type.
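Here is the same trick as a self-contained, Spark-free sketch (a String again stands in for Column; the names are illustrative):

```scala
import scala.language.implicitConversions

trait ConditionType
case class ColumnType(condition: String) extends ConditionType
case class SeqType(condition: Seq[String]) extends ConditionType

implicit def stringToColumnType(s: String): ColumnType = ColumnType(s)
implicit def seqToSeqType(seq: Seq[String]): SeqType = SeqType(seq)

// Expects a ConditionType, but callers can pass a raw String or Seq[String]:
// the compiler inserts the conversion because both wrappers extend ConditionType.
def accept(condition: ConditionType): ConditionType = condition

val a = accept("a = b")            // becomes ColumnType("a = b")
val b = accept(Seq("id", "date"))  // becomes SeqType(Seq("id", "date"))
```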

Scala 3 can resolve such overloads across multiple parameter lists. So, in Scala 3 we could simply write overloaded functions like this:
splitByLeftJoin(leftDF: DataFrame, rightDF: DataFrame, drop: Boolean)(condition: Column)
splitByLeftJoin(leftDF: DataFrame, rightDF: DataFrame, drop: Boolean)(condition: Seq[String])
For further details, see the Scala 3 documentation on overload resolution.
But because my project uses Scala 2.13, I have to do some tricky stuff.

Let’s sum up. We have a private leftJoin function that takes a ConditionType parameter and finds out the actual type by pattern matching. The ConditionType is obtained earlier via implicit conversion from either a Column or a Seq[String]. We also create a temporary constant column, condition, that will help us divide the data frame.

Drop condition

We made the join, and now we have to produce two output data frames and return them.

But first we have to decide whether we want to delete from the resulting data frames all the columns that were joined in from rightDF. For this we have the drop parameter. Because if … else is an expression in Scala, we can assign its result (here, a function) to a new variable.

val dropCondition =
  if (drop) {
    val columns = leftDF.columns.map(leftDF(_))
    (df: DataFrame) => df.select(columns: _*)
  } else {
    (df: DataFrame) => df.drop("condition")
  }

We will use variable dropCondition later.
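As a quick Spark-free aside, here is the same “if/else is an expression” idea choosing between two functions (the names are illustrative):

```scala
// if/else evaluates to a value -- here a function selected by a flag,
// just like dropCondition above.
def pickCleaner(upper: Boolean): String => String =
  if (upper) (s: String) => s.toUpperCase
  else (s: String) => s.trim

println(pickCleaner(upper = true)("spark"))  // SPARK
println(pickCleaner(upper = false)(" ok "))  // ok
```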

Split the data frame

splitByCondition is a very simple function that returns a tuple of two data frames.

def splitByCondition(df: DataFrame)(condition: Column): (DataFrame, DataFrame) = {
  (df.filter(condition), df.filter(not(condition)))
}

The first data frame contains the rows where the condition is true, and the second contains the rows where the condition is false.
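splitByCondition is essentially what `partition` does for plain collections. For intuition, the same idea sketched on a List (illustrative only, not part of the module):

```scala
// The same split on a plain Scala collection.
def splitList[A](xs: List[A])(condition: A => Boolean): (List[A], List[A]) =
  (xs.filter(condition), xs.filterNot(condition))

val (evens, odds) = splitList(List(1, 2, 3, 4, 5))(_ % 2 == 0)
// evens == List(2, 4), odds == List(1, 3, 5)
```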

Return result

Now we have two data frames. We apply the dropCondition function to them and return the results to our user.

val (trueConditionDFResult, falseConditionDFResult) =
  (
    trueConditionDF.transform(dropCondition),
    falseConditionDF.transform(dropCondition)
  )

(trueConditionDFResult, falseConditionDFResult)

That’s it.

Now, to split a data frame by left join (let’s assume it’s a join with some small dictionary table), we just call the splitByLeftJoin function and pass it the join condition, like this:

val (dataDFScope, dataDFNotScope) =
  splitByLeftJoin(
    dataDF,
    dataScopeDF
  ) {
    dataDF("Account") === dataScopeDF("Account") && (
      isnull(dataScopeDF("TxtField4"))
        || dataDF("TxtField4") === dataScopeDF("TxtField4")
    ) && (
      isnull(dataScopeDF("TxtField5"))
        || dataDF("TxtField5") === dataScopeDF("TxtField5")
    )
  }

Below you can see the whole code of the module, including the splitByCondition function.

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{coalesce, col, lit, not}

import scala.language.implicitConversions

object Split {

  /**
   * Performs the splitting of a DataFrame based on a condition.
   * @param df DataFrame to be split based on the condition
   * @param condition Condition expression, should only return true/false
   * @return Tuple of two DataFrames, where:
   *         - the first DataFrame contains rows where the condition is true
   *         - the second DataFrame contains rows where the condition is false
   */
  def splitByCondition(df: DataFrame)(condition: Column): (DataFrame, DataFrame) = {
    (df.filter(condition), df.filter(not(condition)))
  }

  /**
   * Performs the splitting of leftDF by performing a left join with rightDF.
   * @param leftDF DataFrame to be split into two DataFrames based on the condition
   * @param rightDF DataFrame to be joined with
   * @param condition Join condition expression
   * @return Tuple of two DataFrames, where:
   *         - the first DataFrame contains rows where the left join worked based on the condition
   *         - the second DataFrame contains rows where the left join did not work based on the condition
   *
   * Columns from rightDF will be removed from the resulting DataFrames.
   */
  def splitByLeftJoin(leftDF: DataFrame, rightDF: DataFrame)(condition: ConditionType): (DataFrame, DataFrame) = {
    splitByLeftJoin(leftDF, rightDF, drop = true) {
      condition
    }
  }

  /**
   * Performs the splitting of leftDF by performing a left join with rightDF.
   * @param leftDF DataFrame to be split into two DataFrames based on the condition
   * @param rightDF DataFrame to be joined with
   * @param drop Whether to remove the columns from rightDF in the resulting DataFrames
   * @param condition Join condition expression
   * @return Tuple of two DataFrames, where:
   *         - the first DataFrame contains rows where the left join worked based on the condition
   *         - the second DataFrame contains rows where the left join did not work based on the condition
   */
  def splitByLeftJoin(leftDF: DataFrame, rightDF: DataFrame, drop: Boolean)(condition: ConditionType): (DataFrame, DataFrame) = {

    val joinDF =
      leftJoin(
        leftDF,
        rightDF.withColumn("condition", lit(true)),
        condition
      )
        .withColumn("condition", coalesce(col("condition"), lit(false)))

    val dropCondition =
      if (drop) {
        val columns = leftDF.columns.map(leftDF(_))
        (df: DataFrame) => df.select(columns: _*)
      } else {
        (df: DataFrame) => df.drop("condition")
      }

    val (trueConditionDF, falseConditionDF) = splitByCondition(joinDF)(col("condition"))

    val (trueConditionDFResult, falseConditionDFResult) =
      (
        trueConditionDF.transform(dropCondition),
        falseConditionDF.transform(dropCondition)
      )

    (trueConditionDFResult, falseConditionDFResult)
  }

  private def leftJoin(leftDF: DataFrame, rightDF: DataFrame, condition: ConditionType): DataFrame = {
    condition match {
      case ColumnType(cond) =>
        leftDF.join(rightDF, cond, "left")
      case SeqType(cond) =>
        leftDF.join(rightDF, cond, "left")
      case _ =>
        throw new IllegalArgumentException
    }
  }

  trait ConditionType

  case class ColumnType(condition: Column) extends ConditionType
  case class SeqType(condition: Seq[String]) extends ConditionType

  implicit def columnToColumnType(column: Column): ColumnType = ColumnType(column)
  implicit def SeqToSeqType(seq: Seq[String]): SeqType = SeqType(seq)

}
