Scala Functional Programming with Spark Datasets

Eric Tome
Published in CodeX · 8 min read · Mar 22, 2021

This tutorial gives examples that you can use to transform your data with Scala and Spark. It focuses on how to use Spark Datasets after reading in your data and before writing it out: the Transform in Extract, Transform, Load (ETL).

One of the benefits of writing code with Scala on Spark is that Scala lets you write in an object-oriented programming (OOP) style or a functional programming (FP) style. This is useful when you have Java developers who only know how to write code in an OOP style. However, Spark is a distributed processing engine that benefits from code written in an FP style. A transformation that is a pure function of its input returns the same result no matter where or how many times it runs, and that matters because Spark may re-run tasks on other executors after a failure. In my opinion, it’s also easier to write unit tests if you write functions that are pure, side-effect-free, and small. The goal of a Scala/Spark developer should be to move toward writing their applications in a functional style. This means using pure functions, immutable values, higher-order functions, and composition.
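The four ideas in the last sentence fit in a few lines of plain Scala. The sketch below is an illustration, not code from this tutorial. It uses an ordinary Seq in place of a Dataset, and the names Sale, isLarge, keep, totalByPerson, and largeTotals are hypothetical.

```scala
// Plain Scala collections stand in for a Dataset in this sketch.
object FpBasics {
  // Immutable value: a case class's fields are vals and cannot be reassigned.
  final case class Sale(personId: Int, amountDollars: Double)

  // Pure function: the result depends only on the argument, with no side effects.
  def isLarge(s: Sale): Boolean = s.amountDollars > 600

  // Higher-order functions: each one takes or returns a function.
  def keep(p: Sale => Boolean): Seq[Sale] => Seq[Sale] = sales => sales.filter(p)

  val totalByPerson: Seq[Sale] => Map[Int, Double] =
    sales => sales.groupBy(_.personId).map { case (id, ss) => id -> ss.map(_.amountDollars).sum }

  // Composition: a pipeline built from small steps that can each be tested alone.
  val largeTotals: Seq[Sale] => Map[Int, Double] = keep(isLarge) andThen totalByPerson
}
```

Spark supports the same pattern one level up. Dataset.transform accepts a Dataset[T] => Dataset[U], so small steps like these can be chained as ds.transform(step1).transform(step2).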

If these ideas are new to you, take some time to understand why they are important in the Spark world. An excellent resource is the book Functional Programming, which is available from O’Reilly.

Imports

Import only what you are going to use in your application. These are the imports we will use in this tutorial.

import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Encoder, Encoders, DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{avg, sum}
import java.sql.Date

Create your Object

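We create an object, which is Scala’s built-in form of the Singleton design pattern: exactly one instance exists. We extend it with the App trait to make it runnable. You can also define a main method instead of extending App. The Spark documentation recommends def main and notes that subclasses of scala.App may not work correctly.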

object FunctionalSpark extends App {

Helper Function

This helper converts a DataFrame to a typed Dataset[T] with as[T]. You should always strongly type your data.

def toDS[T <: Product: Encoder](df: DataFrame): Dataset[T] = df.as[T]
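The context bound is what lets df.as[T] compile. T: Encoder is shorthand for an extra implicit Encoder[T] parameter, so every call to toDS needs an Encoder for its target type in scope. Below is a minimal sketch of the same mechanism without Spark. RowDecoder is a hypothetical stand-in for Encoder, and rows are modeled as Map[String, Any].

```scala
object ContextBoundSketch {
  // Hypothetical stand-in for Spark's Encoder[T]: evidence that a T can be built from a row.
  trait RowDecoder[T] {
    def decode(row: Map[String, Any]): T
  }

  final case class Person(personId: Int, firstName: String, lastName: String)

  object Person {
    // Defined in the companion object, so the compiler finds it wherever Person is used.
    implicit val decoder: RowDecoder[Person] = new RowDecoder[Person] {
      def decode(row: Map[String, Any]): Person =
        Person(
          row("personId").asInstanceOf[Int],
          row("firstName").asInstanceOf[String],
          row("lastName").asInstanceOf[String]
        )
    }
  }

  // `T <: Product: RowDecoder` desugars to an extra (implicit ev: RowDecoder[T]) parameter,
  // just as `T <: Product: Encoder` in toDS demands an implicit Encoder[T].
  def toTyped[T <: Product: RowDecoder](rows: Seq[Map[String, Any]]): Seq[T] = {
    val decoder = implicitly[RowDecoder[T]]
    rows.map(row => decoder.decode(row))
  }
}
```

If no RowDecoder[X] is in scope, toTyped[X](...) fails at compile time. The same thing happens with toDS when no Encoder is available.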

Create Datasets

We’ll create two datasets for use in this tutorial. In your own project, you’d typically read data using your own framework, but here we create the data manually so the code can run in any environment.

Case classes strongly type your data. Applying one to a DataFrame lets you use the typed Dataset API in Spark. In Scala, DataFrame is simply a type alias for Dataset[Row], which is an untyped Dataset.

final case class Person(
  personId: Int,
  firstName: String,
  lastName: String)
final case class Sales(
  date: Date,
  personId: Int,
  customerName: String,
  amountDollars: Double)

We create our data as two Seq[Row] values: one for people and one for sales.

val personData: Seq[Row] = Seq(
  Row(1, "Eric", "Tome"),
  Row(2, "Jennifer", "C"),
  Row(3, "Cara", "Rae")
)
// Date.valueOf("yyyy-mm-dd") is easier to read than epoch milliseconds,
// whose calendar date depends on the JVM's default time zone.
val salesData: Seq[Row] = Seq(
  Row(Date.valueOf("2020-01-01"), 1, "Third Bank", 100.29),
  Row(Date.valueOf("2020-04-01"), 3, "Pet's Paradise", 1233451.33),
  Row(Date.valueOf("2020-04-01"), 2, "Small Shoes", 4543.35),
  Row(Date.valueOf("2020-07-01"), 1, "PaperCo", 84990.15),
  Row(Date.valueOf("2020-10-01"), 1, "Disco Balls'r'us", 504.00),
  Row(Date.valueOf("2020-10-01"), 2, "Big Shovels", 9.99)
)

Spark can build a Dataset from a Scala Seq. The following code derives a StructType schema from each case class defined above with Encoders.product. The function getDSFromSeq takes the data and its schema, creates a DataFrame, and converts it to a strongly typed Dataset. All of the code relies on a SparkSession named spark, with import spark.implicits._ in scope to supply Encoders for the case classes.

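// All examples use this SparkSession; local[*] runs Spark in-process for the tutorial.
val spark: SparkSession = SparkSession.builder()
  .appName("FunctionalSpark")
  .master("local[*]")
  .getOrCreate()
// Supplies Encoders for the case classes and the $"column" syntax used later.
import spark.implicits._
private val personSchema: StructType = Encoders.product[Person].schema
private val salesSchema: StructType = Encoders.product[Sales].schema
def getDSFromSeq[T <: Product: Encoder](data: Seq[Row], schema: StructType): Dataset[T] =
  spark
    .createDataFrame(spark.sparkContext.parallelize(data), schema)
    .as[T]
val personDS: Dataset[Person] = getDSFromSeq[Person](personData, personSchema)
val salesDS: Dataset[Sales] = getDSFromSeq[Sales](salesData, salesSchema)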

The show() method prints a Dataset or DataFrame to the console (the first 20 rows by default). Calling show() on personDS and salesDS gives:

Person Data

+--------+---------+--------+
|personId|firstName|lastName|
+--------+---------+--------+
| 1| Eric| Tome|
| 2| Jennifer| C|
| 3| Cara| Rae|
+--------+---------+--------+

Sales Data

+----------+--------+----------------+-------------+
| date|personId| customerName|amountDollars|
+----------+--------+----------------+-------------+
|2020-01-01| 1| Third Bank| 100.29|
|2020-04-01| 3| Pet's Paradise| 1233451.33|
|2020-04-01| 2| Small Shoes| 4543.35|
|2020-07-01| 1| PaperCo| 84990.15|
|2020-10-01| 1|Disco Balls'r'us| 504.0|
|2020-10-01| 2| Big Shovels| 9.99|
+----------+--------+----------------+-------------+

Filtering

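There are various ways to filter your data. In each example below, filter takes a Scala function that returns true for the records to keep. Filter where firstName contains "Eric".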

personDS.filter(r => r.firstName.contains("Eric"))

Results:

+--------+---------+--------+
|personId|firstName|lastName|
+--------+---------+--------+
| 1| Eric| Tome|
+--------+---------+--------+

Filter where personId is equal to 1.

salesDS.filter(r => r.personId == 1)

Results:

+----------+--------+----------------+-------------+
| date|personId| customerName|amountDollars|
+----------+--------+----------------+-------------+
|2020-01-01| 1| Third Bank| 100.29|
|2020-07-01| 1| PaperCo| 84990.15|
|2020-10-01| 1|Disco Balls'r'us| 504.0|
+----------+--------+----------------+-------------+

Filter where amountDollars is greater than 100.

salesDS.filter(r => r.amountDollars > 100)

Results:

+----------+--------+----------------+-------------+
| date|personId| customerName|amountDollars|
+----------+--------+----------------+-------------+
|2020-01-01| 1| Third Bank| 100.29|
|2020-04-01| 3| Pet's Paradise| 1233451.33|
|2020-04-01| 2| Small Shoes| 4543.35|
|2020-07-01| 1| PaperCo| 84990.15|
|2020-10-01| 1|Disco Balls'r'us| 504.0|
+----------+--------+----------------+-------------+

Filter where amountDollars is greater than 600.

salesDS.filter(r => r.amountDollars > 600)

Results:

+----------+--------+--------------+-------------+
| date|personId| customerName|amountDollars|
+----------+--------+--------------+-------------+
|2020-04-01| 3|Pet's Paradise| 1233451.33|
|2020-04-01| 2| Small Shoes| 4543.35|
|2020-07-01| 1| PaperCo| 84990.15|
+----------+--------+--------------+-------------+

Filter where amountDollars is strictly between 600 and 5000.

salesDS.filter(r => r.amountDollars > 600 && r.amountDollars < 5000)

Results:

+----------+--------+------------+-------------+
| date|personId|customerName|amountDollars|
+----------+--------+------------+-------------+
|2020-04-01| 2| Small Shoes| 4543.35|
+----------+--------+------------+-------------+
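Each predicate above is an ordinary Scala function from a record to a Boolean. That means you can name the predicates and test them on a local collection before passing them to Dataset.filter. The sketch below assumes a trimmed-down copy of Sales without the date field, and the predicate names are hypothetical.

```scala
object SalesPredicates {
  // Trimmed-down copy of the tutorial's Sales case class (date omitted for brevity).
  final case class Sales(personId: Int, customerName: String, amountDollars: Double)

  // Each predicate is a plain, pure Sales => Boolean: the function type Dataset.filter accepts.
  def byPerson(id: Int): Sales => Boolean = _.personId == id
  def amountAbove(min: Double): Sales => Boolean = _.amountDollars > min
  // Exclusive bounds, matching the "between 600 and 5000" example.
  def amountBetween(lo: Double, hi: Double): Sales => Boolean =
    s => s.amountDollars > lo && s.amountDollars < hi

  // Predicates compose like any other functions.
  def both[A](p: A => Boolean, q: A => Boolean): A => Boolean = a => p(a) && q(a)
}
```

Against the tutorial's own Sales class, the same definitions could be passed directly, as in salesDS.filter(amountAbove(600)). There is one trade-off to weigh. Spark cannot look inside a Scala lambda, so these typed filters are not pushed down to the data source. A Column expression such as salesDS.filter($"amountDollars" > 600) stays visible to the Catalyst optimizer.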

Renaming Columns

In Spark, you can add a column with withColumn and rename one with withColumnRenamed. Let's say we want to rename all of our columns to the upper-case names they would have in a database. We could call withColumnRenamed once for every column in the Dataset, like this:

df.withColumnRenamed("col1", "newcol1")
.withColumnRenamed("col2", "newcol2")
.withColumnRenamed("col3", "newcol3")
.withColumnRenamed("col4", "newcol4")
...
.withColumnRenamed("coln", "newcoln")

However, when renaming a large number of columns, there is a more elegant approach:

  1. Create a case class that defines how your final set of data should look.
  2. Create a function that returns a Map[String, String] where the first string is the current column name, and the second is the new name.
  3. Create a function that folds over that Map, starting from the input Dataset. At each step, withColumnRenamed renames one column from its current name (the Map key) to its new name (the Map value). The result is then converted to a Dataset typed with your final case class.
final case class SalesChangeColumnNames(
  SALES_DATE: Date,
  PERSON_ID: Int,
  CUSTOMER_NAME: String,
  SALES_IN_DOLLARS: Double)
def saleColumns: Map[String, String] =
  Map(
    "date" -> "SALES_DATE",
    "personId" -> "PERSON_ID",
    "customerName" -> "CUSTOMER_NAME",
    "amountDollars" -> "SALES_IN_DOLLARS"
  )
def renameColumns(ds: Dataset[Sales], m: Map[String, String]): Dataset[SalesChangeColumnNames] =
  toDS {
    // Start from the untyped DataFrame and apply one withColumnRenamed per Map entry.
    m.foldLeft(ds.toDF()) { case (acc, (oldName, newName)) => acc.withColumnRenamed(oldName, newName) }
  }
renameColumns(salesDS, saleColumns)

Result:

+----------+---------+----------------+----------------+
|SALES_DATE|PERSON_ID| CUSTOMER_NAME|SALES_IN_DOLLARS|
+----------+---------+----------------+----------------+
|2020-01-01| 1| Third Bank| 100.29|
|2020-04-01| 3| Pet's Paradise| 1233451.33|
|2020-04-01| 2| Small Shoes| 4543.35|
|2020-07-01| 1| PaperCo| 84990.15|
|2020-10-01| 1|Disco Balls'r'us| 504.0|
|2020-10-01| 2| Big Shovels| 9.99|
+----------+---------+----------------+----------------+
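The sketch below traces what the fold in renameColumns does step by step. It replaces the DataFrame with a hypothetical Frame class that holds only column names. Frame's withColumnRenamed imitates Spark's behavior of leaving the schema unchanged when the named column does not exist.

```scala
object RenameFoldSketch {
  // Hypothetical stand-in for a DataFrame: this sketch tracks only column names.
  final case class Frame(columns: Vector[String]) {
    // Mirrors withColumnRenamed: renaming a column that is not present is a no-op.
    def withColumnRenamed(existing: String, newName: String): Frame =
      Frame(columns.map(c => if (c == existing) newName else c))
  }

  // Same shape as renameColumns: start from the input and apply one rename per Map entry.
  def renameAll(frame: Frame, renames: Map[String, String]): Frame =
    renames.foldLeft(frame) { case (acc, (from, to)) => acc.withColumnRenamed(from, to) }
}
```

The renames are applied one at a time. So when a new name is also an old name (for example, swapping two columns), the result can depend on the Map's iteration order. Keeping old and new names disjoint, as saleColumns does, avoids that.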

Joining

Joining data in Spark is simple, but when you join two different sets of data, you need to create a new case class to type the output of the join. Spark supports inner, cross, left outer, right outer, and full outer joins, as well as left semi and left anti joins. In this case, we join on personId and use a left join.

final case class JoinedData(
  personId: Int,
  firstName: String,
  lastName: String,
  date: Date,
  customerName: String,
  amountDollars: Double)
// Joining on a Seq of column names keeps a single personId column in the output.
val joinedData: Dataset[JoinedData] =
  toDS(personDS.join(salesDS, Seq("personId"), "left"))

Results:

+--------+---------+--------+----------+----------------+-------------+
|personId|firstName|lastName| date| customerName|amountDollars|
+--------+---------+--------+----------+----------------+-------------+
| 1| Eric| Tome|2020-01-01| Third Bank| 100.29|
| 1| Eric| Tome|2020-07-01| PaperCo| 84990.15|
| 1| Eric| Tome|2020-10-01|Disco Balls'r'us| 504.0|
| 3| Cara| Rae|2020-04-01| Pet's Paradise| 1233451.33|
| 2| Jennifer| C|2020-04-01| Small Shoes| 4543.35|
| 2| Jennifer| C|2020-10-01| Big Shovels| 9.99|
+--------+---------+--------+----------+----------------+-------------+
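Every person in the sample data has at least one sale, so this left join never produces an unmatched row. When it does, Spark fills the right-hand columns with null. Decoding a null into a primitive field such as amountDollars: Double fails at runtime, whereas a field declared as Option[Double] decodes it as None. The plain-Scala sketch below uses hypothetical names and no Spark. It shows the left-join semantics and why Option fits the unmatched case.

```scala
object LeftJoinSketch {
  final case class Person(personId: Int, firstName: String)
  final case class Sale(personId: Int, amountDollars: Double)
  // Hypothetical output type: the right-hand field is an Option,
  // because a left join keeps people who have no matching sale.
  final case class Joined(personId: Int, firstName: String, amountDollars: Option[Double])

  def leftJoin(people: Seq[Person], sales: Seq[Sale]): Seq[Joined] = {
    val salesById: Map[Int, Seq[Sale]] = sales.groupBy(_.personId)
    people.flatMap { p =>
      salesById.get(p.personId) match {
        // One output row per matching sale.
        case Some(matches) => matches.map(s => Joined(p.personId, p.firstName, Some(s.amountDollars)))
        // No match: keep the person, with an empty sales field.
        case None => Seq(Joined(p.personId, p.firstName, None))
      }
    }
  }
}
```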

Using map

The map method is one of the most useful tools in functional Scala, and Datasets support it too. We use it here to transform our data from one type of object to another. map applies a function to each record in the Dataset and builds a new object from it. That function calls user-defined functions (dollarToEuro, initials) and methods on String (toUpperCase, toLowerCase, trim). We again create a case class that defines the object we are mapping to. A map can also change column order: here, date moves to the first column. The output of our map is a Dataset of type JoinedDataWithEuro.

final case class JoinedDataWithEuro(
  date: Date,
  personId: Int,
  firstName: String,
  lastName: String,
  initials: String,
  customerName: String,
  amountDollars: Double,
  amountEuros: Double)
// Illustrative, hard-coded rate. If 1.19 is meant as dollars per euro,
// converting dollars to euros would divide by it; a real pipeline should look up the rate.
def dollarToEuro(d: Double): Double = d * 1.19
// take(1) returns "" for an empty name, where substring(0, 1) would throw.
def initials(firstName: String, lastName: String): String =
  s"${firstName.take(1)}${lastName.take(1)}"
val joinedDataWithEuro: Dataset[JoinedDataWithEuro] =
  joinedData.map(r =>
    JoinedDataWithEuro(
      r.date,
      r.personId,
      r.firstName.toUpperCase(), // modified column
      r.lastName.toLowerCase(), // modified column
      initials(r.firstName, r.lastName), // new column
      r.customerName.trim(), // modified column
      r.amountDollars,
      dollarToEuro(r.amountDollars) // new column
    )
  )

Results:

+----------+--------+---------+--------+--------+----------------+-------------+------------------+
| date|personId|firstName|lastName|initials| customerName|amountDollars| amountEuros|
+----------+--------+---------+--------+--------+----------------+-------------+------------------+
|2020-01-01| 1| ERIC| tome| ET| Third Bank| 100.29| 119.3451|
|2020-07-01| 1| ERIC| tome| ET| PaperCo| 84990.15|101138.27849999999|
|2020-10-01| 1| ERIC| tome| ET|Disco Balls'r'us| 504.0| 599.76|
|2020-04-01| 3| CARA| rae| CR| Pet's Paradise| 1233451.33| 1467807.0827|
|2020-04-01| 2| JENNIFER| c| JC| Small Shoes| 4543.35| 5406.5865|
|2020-10-01| 2| JENNIFER| c| JC| Big Shovels| 9.99| 11.8881|
+----------+--------+---------+--------+--------+----------------+-------------+------------------+
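The lambda passed to map is easier to test once it is pulled out into a named function. The sketch below shows one way to do that, with hypothetical names. enrich takes the currency conversion as a parameter, so it stays pure and a test can supply an exact rate instead of the hard-coded 1.19. The case classes are copies of the tutorial's.

```scala
import java.sql.Date

object EnrichSketch {
  // Copies of the tutorial's case classes.
  final case class JoinedData(personId: Int, firstName: String, lastName: String,
                              date: Date, customerName: String, amountDollars: Double)
  final case class JoinedDataWithEuro(date: Date, personId: Int, firstName: String,
                                      lastName: String, initials: String, customerName: String,
                                      amountDollars: Double, amountEuros: Double)

  // take(1) returns "" instead of throwing on an empty name.
  def initials(firstName: String, lastName: String): String =
    firstName.take(1) + lastName.take(1)

  // The conversion is a parameter, so enrich is pure and tests can pass an exact rate.
  def enrich(toEuro: Double => Double): JoinedData => JoinedDataWithEuro = r =>
    JoinedDataWithEuro(
      r.date,
      r.personId,
      r.firstName.toUpperCase(),
      r.lastName.toLowerCase(),
      initials(r.firstName, r.lastName),
      r.customerName.trim(),
      r.amountDollars,
      toEuro(r.amountDollars)
    )
}
```

If enrich were defined over the tutorial's own case classes, the Spark call would become joinedData.map(enrich(dollarToEuro)), with import spark.implicits._ supplying the Encoder for JoinedDataWithEuro.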

Aggregating

Aggregation uses the untyped DataFrame API (groupBy followed by agg), and we return to a strongly typed Dataset afterward. Many functions can be used in an aggregation: avg, sum, count, and so on. The example below aggregates our sales data by person, computing the sum and the mean of each person's sales in dollars and in euros.

final case class TotalSalesByPerson(
  personId: Int,
  firstName: String,
  lastName: String,
  initials: String,
  sumAmountDollars: Double,
  sumAmountEuros: Double,
  avgAmountDollars: Double,
  avgAmountEuros: Double)
val totalSalesByPerson: Dataset[TotalSalesByPerson] =
  toDS {
    joinedDataWithEuro
      .groupBy($"personId", $"firstName", $"lastName", $"initials")
      .agg(
        sum($"amountDollars").alias("sumAmountDollars"),
        sum($"amountEuros").alias("sumAmountEuros"),
        avg($"amountDollars").alias("avgAmountDollars"),
        avg($"amountEuros").alias("avgAmountEuros")
      )
  }

Results:

+--------+---------+--------+--------+-----------------+------------------+------------------+------------------+
|personId|firstName|lastName|initials| sumAmountDollars| sumAmountEuros| avgAmountDollars| avgAmountEuros|
+--------+---------+--------+--------+-----------------+------------------+------------------+------------------+
| 2| JENNIFER| c| JC| 4553.34|5418.4746000000005| 2276.67|2709.2373000000002|
| 3| CARA| rae| CR| 1233451.33| 1467807.0827| 1233451.33| 1467807.0827|
| 1| ERIC| tome| ET|85594.43999999999|101857.38359999999|28531.479999999996| 33952.4612|
+--------+---------+--------+--------+-----------------+------------------+------------------+------------------+
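A small local reference implementation is one way to sanity-check an aggregation in a unit test. The sketch below computes the same dollar sums and averages as the groupBy/agg above, using plain Scala collections. The trimmed-down case classes are hypothetical. The sketch sorts by personId because Spark does not guarantee the order of grouped rows, as the 2, 3, 1 order in the output above shows.

```scala
object TotalsSketch {
  // Trimmed-down row: only the grouping key and the measure.
  final case class Sale(personId: Int, amountDollars: Double)
  final case class Totals(personId: Int, sumAmountDollars: Double, avgAmountDollars: Double)

  // Local equivalent of groupBy(personId).agg(sum(...), avg(...)); the average is sum / count.
  def totalsByPerson(sales: Seq[Sale]): Seq[Totals] =
    sales
      .groupBy(_.personId)
      .map { case (id, rows) =>
        val amounts = rows.map(_.amountDollars)
        Totals(id, amounts.sum, amounts.sum / amounts.size)
      }
      .toSeq
      .sortBy(_.personId) // Sort so comparisons do not depend on grouping order.
}
```

Floating-point sums can differ in the last digits depending on summation order, as 85594.43999999999 shows, so comparisons against Spark's output are safer with a small tolerance.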

Conclusion

This was a brief introduction to how to process data in Spark using Scala. We focused on functional implementations of transformations using Datasets. If you have any questions, feel free to drop a comment!

Eric Tome, writer for CodeX

Solutions Architect at Databricks. I love working with data, math, coding, music, soccer, working out, and video games.