Scala Functional Programming with Spark Datasets
This tutorial will give examples that you can use to transform your data using Scala and Spark. The focus of this tutorial is how to use Spark Datasets after reading in your data and before writing it out… the Transform in Extract, Transform, Load (ETL).
One of the benefits of writing code with Scala on Spark is that Scala allows you to write in an object-oriented programming (OOP) or a functional programming (FP) style. This is useful when you have Java developers who only know how to write code in an OOP style. However, Spark is a distributed processing engine that benefits from writing code in an FP style. In my opinion, it’s also easier to write unit tests if you write functions that are pure, side-effect-free, and small. The goal of a Scala/Spark developer should be to move toward writing their applications in a functional style. This means using pure functions, immutable values, higher-order functions, and composition.
If these ideas are new to you, take some time to understand why they are important in the Spark world. An excellent resource is Functional Programming, available from O’Reilly.
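As a tiny illustration of that style (our own sketch, not part of the tutorial’s pipeline), small pure functions are easy to unit test and compose well with higher-order functions:
// a pure function: the same input always produces the same output, with no side effects
def addTax(rate: Double)(amount: Double): Double = amount * (1 + rate)

// partial application and composition of small functions
val addVat: Double => Double = addTax(0.20)
val roundCents: Double => Double = d => math.round(d * 100) / 100.0
val priceWithVat: Double => Double = addVat.andThen(roundCents)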
Imports
Import only what you are going to use in your application. These are the imports we will use in this tutorial.
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Encoder, Encoders, DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{avg, sum}
import java.sql.Date
Create your Object
We’re creating an object, which in Scala is the Singleton design pattern. We'll extend it with the trait App to make it runnable. You also have the option to use def main instead of App.
object FunctionalSpark extends App {
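The rest of the code assumes a SparkSession named spark and its implicits in scope; the implicits provide the Encoder instances for our case classes and the $ column syntax used later. A minimal local setup (the app name and master shown here are placeholders) might look like this:
val spark: SparkSession = SparkSession.builder()
  .appName("FunctionalSpark") // placeholder application name
  .master("local[*]")         // placeholder master, useful for local runs
  .getOrCreate()

import spark.implicits._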
Helper Function
This is a helper function for casting a DataFrame to a Dataset. You should always strongly type your data.
def toDS[T <: Product: Encoder](df: DataFrame): Dataset[T] = df.as[T]
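For example, with the implicits in scope and a case class such as the Person defined in the next section, a DataFrame whose columns line up with it (a hypothetical peopleDF here) could be converted like this:
// peopleDF is a hypothetical DataFrame with columns personId, firstName, lastName
val typedPeople: Dataset[Person] = toDS[Person](peopleDF)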
Create Datasets
We’ll create two datasets for use in this tutorial. In your own project, you’d typically be reading data using your own framework, but we’ll manually create a dataset so this code can be run in any environment.
Case classes are used to strongly type your data. Applying them to DataFrames allows you to use the Dataset API in Spark. A DataFrame is equivalent to Dataset[Row], which is an untyped Dataset.
final case class Person(
  personId: Int,
  firstName: String,
  lastName: String)

final case class Sales(
  date: Date,
  personId: Int,
  customerName: String,
  amountDollars: Double)
This is our data, which we’ll create using Seq types. We'll use two of them, one for people, and the other a set of sales data.
val personData: Seq[Row] = Seq(
  Row(1, "Eric", "Tome"),
  Row(2, "Jennifer", "C"),
  Row(3, "Cara", "Rae")
)

val salesData: Seq[Row] = Seq(
  Row(new Date(1577858400000L), 1, "Third Bank", 100.29),
  Row(new Date(1585717200000L), 3, "Pet's Paradise", 1233451.33),
  Row(new Date(1585717200000L), 2, "Small Shoes", 4543.35),
  Row(new Date(1593579600000L), 1, "PaperCo", 84990.15),
  Row(new Date(1601528400000L), 1, "Disco Balls'r'us", 504.00),
  Row(new Date(1601528400000L), 2, "Big Shovels", 9.99)
)
Using Spark, we can read data from Scala Seq objects. The following code creates a StructType object from each of the case classes defined above. Then we have a function getDSFromSeq that takes the parameters data and schema. We then use Spark to read our Seq objects while strongly typing them.
private val personSchema: StructType = Encoders.product[Person].schema
private val salesSchema: StructType = Encoders.product[Sales].schema

def getDSFromSeq[T <: Product: Encoder](data: Seq[Row], schema: StructType): Dataset[T] =
  spark
    .createDataFrame(
      spark.sparkContext.parallelize(data),
      schema
    ).as[T]

val personDS: Dataset[Person] = getDSFromSeq[Person](personData, personSchema)
val salesDS: Dataset[Sales] = getDSFromSeq[Sales](salesData, salesSchema)
The show() method will display the data in a Dataset or DataFrame on the console.
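Calling it on our two Datasets with something like the following (the exact calls are not shown in the original listing) produces the output below:
personDS.show()
salesDS.show()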
Person Data
+--------+---------+--------+
|personId|firstName|lastName|
+--------+---------+--------+
|       1|     Eric|    Tome|
|       2| Jennifer|       C|
|       3|     Cara|     Rae|
+--------+---------+--------+
Sales Data
+----------+--------+----------------+-------------+
| date|personId| customerName|amountDollars|
+----------+--------+----------------+-------------+
|2020-01-01| 1| Third Bank| 100.29|
|2020-04-01| 3| Pet's Paradise| 1233451.33|
|2020-04-01| 2| Small Shoes| 4543.35|
|2020-07-01| 1| PaperCo| 84990.15|
|2020-10-01| 1|Disco Balls'r'us| 504.0|
|2020-10-01| 2| Big Shovels| 9.99|
+----------+--------+----------------+-------------+
Filtering
There are various ways to filter your data. Here are some examples, starting with a filter where firstName contains "Eric":
personDS.filter(r => r.firstName.contains("Eric"))
Results:
+--------+--------+--------+
|personId|firstName|lastName|
+--------+--------+--------+
| 1| Eric| Tome|
+--------+--------+--------+
Filter where personId is equal to 1.
salesDS.filter(r => r.personId.equals(1))
Results:
+----------+--------+----------------+-------------+
| date|personId| customerName|amountDollars|
+----------+--------+----------------+-------------+
|2020-01-01| 1| Third Bank| 100.29|
|2020-07-01| 1| PaperCo| 84990.15|
|2020-10-01| 1|Disco Balls'r'us| 504.0|
+----------+--------+----------------+-------------+
Filter where amountDollars is greater than 100.
salesDS.filter(r => r.amountDollars > 100)
Results:
+----------+--------+----------------+-------------+
| date|personId| customerName|amountDollars|
+----------+--------+----------------+-------------+
|2020-01-01| 1| Third Bank| 100.29|
|2020-04-01| 3| Pet's Paradise| 1233451.33|
|2020-04-01| 2| Small Shoes| 4543.35|
|2020-07-01| 1| PaperCo| 84990.15|
|2020-10-01| 1|Disco Balls'r'us| 504.0|
+----------+--------+----------------+-------------+
Filter where amountDollars is greater than 600.
salesDS.filter(r => r.amountDollars > 600)
Results:
+----------+--------+--------------+-------------+
| date|personId| customerName|amountDollars|
+----------+--------+--------------+-------------+
|2020-04-01| 3|Pet's Paradise| 1233451.33|
|2020-04-01| 2| Small Shoes| 4543.35|
|2020-07-01| 1| PaperCo| 84990.15|
+----------+--------+--------------+-------------+
Filter where amountDollars is between 600 and 5000.
salesDS.filter(r => r.amountDollars > 600 && r.amountDollars < 5000)
Results:
+----------+--------+------------+-------------+
| date|personId|customerName|amountDollars|
+----------+--------+------------+-------------+
|2020-04-01| 2| Small Shoes| 4543.35|
+----------+--------+------------+-------------+
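Because these filters are ordinary functions from Sales to Boolean, you can also pull the predicates out as named values and combine them. A small sketch (the names bigSale and underFiveThousand are ours, not from the tutorial):
val bigSale: Sales => Boolean = s => s.amountDollars > 600
val underFiveThousand: Sales => Boolean = s => s.amountDollars < 5000

// same result as the range filter above
salesDS.filter(s => bigSale(s) && underFiveThousand(s))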
Renaming Columns
You can create a new column or rename columns in Spark using withColumn or withColumnRenamed. Let's say we want to rename all of our columns as they would appear in a database. We could call withColumnRenamed for every column in the Dataset, like this:
df.withColumnRenamed("col1", "newcol1")
.withColumnRenamed("col2", "newcol2")
.withColumnRenamed("col3", "newcol3")
.withColumnRenamed("col4", "newcol4")
...
.withColumnRenamed("coln", "newcoln")
However, when modifying a large number of columns, there are more elegant solutions:
- Create a case class that defines how your final set of data should look.
- Create a function that returns a Map[String, String] where the first string is the current column name and the second is the new name.
- Create a function that takes that Map and folds over the input Dataset. The function applied within the fold is withColumnRenamed, which takes the current column name and the new name from the Map. A new Dataset is returned, typed with your final case class.
final case class SalesChangeColumnNames(
  SALES_DATE: Date,
  PERSON_ID: Int,
  CUSTOMER_NAME: String,
  SALES_IN_DOLLARS: Double)

def saleColumns: Map[String, String] =
  Map(
    "date" -> "SALES_DATE",
    "personId" -> "PERSON_ID",
    "customerName" -> "CUSTOMER_NAME",
    "amountDollars" -> "SALES_IN_DOLLARS"
  )

def renameColumns(ds: Dataset[Sales], m: Map[String, String]): Dataset[SalesChangeColumnNames] =
  toDS {
    m.foldLeft(ds.toDF())((acc, colnames) => acc.withColumnRenamed(colnames._1, colnames._2))
  }

renameColumns(salesDS, saleColumns)
Result:
+----------+---------+----------------+----------------+
|SALES_DATE|PERSON_ID| CUSTOMER_NAME|SALES_IN_DOLLARS|
+----------+---------+----------------+----------------+
|2020-01-01| 1| Third Bank| 100.29|
|2020-04-01| 3| Pet's Paradise| 1233451.33|
|2020-04-01| 2| Small Shoes| 4543.35|
|2020-07-01| 1| PaperCo| 84990.15|
|2020-10-01| 1|Disco Balls'r'us| 504.0|
|2020-10-01| 2| Big Shovels| 9.99|
+----------+---------+----------------+----------------+
Joining
Joining data in Spark is simple, but when joining two different sets of data, you will need to create a new case class and type the output of the join. Spark supports a wide variety of join types, including inner, left, right, full outer, semi, and anti joins. In this case, we join on personId and use a left join.
final case class JoinedData(
  personId: Int,
  firstName: String,
  lastName: String,
  date: Date,
  customerName: String,
  amountDollars: Double)

val joinedData: Dataset[JoinedData] =
  toDS(personDS.join(salesDS, Seq("personId"), "left"))
Results:
+--------+---------+--------+----------+----------------+-------------+
|personId|firstName|lastName|      date|    customerName|amountDollars|
+--------+---------+--------+----------+----------------+-------------+
|       1|     Eric|    Tome|2020-01-01|      Third Bank|       100.29|
|       1|     Eric|    Tome|2020-07-01|         PaperCo|     84990.15|
|       1|     Eric|    Tome|2020-10-01|Disco Balls'r'us|        504.0|
|       3|     Cara|     Rae|2020-04-01|  Pet's Paradise|   1233451.33|
|       2| Jennifer|       C|2020-04-01|     Small Shoes|      4543.35|
|       2| Jennifer|       C|2020-10-01|     Big Shovels|         9.99|
+--------+---------+--------+----------+----------------+-------------+
Using Map
The map function is a powerful feature of the Scala language. We use it here to transform our data from one type of object to another. The map function iterates over each record in the Dataset, mapping that record to a new object using user-defined functions (dollarToEuro, initials) and methods on the String type (toUpperCase, toLowerCase, trim). We again create a case class that describes the object we are mapping to. Column order can also be changed using map; you can see here that we move date to the top of the column list. The output of our map is a Dataset of type JoinedDataWithEuro.
final case class JoinedDataWithEuro(
  date: Date,
  personId: Int,
  firstName: String,
  lastName: String,
  initials: String,
  customerName: String,
  amountDollars: Double,
  amountEuros: Double)

def dollarToEuro(d: Double): Double = d * 1.19

def initials(firstName: String, lastName: String): String =
  s"${firstName.substring(0, 1)}${lastName.substring(0, 1)}"

val joinedDataWithEuro: Dataset[JoinedDataWithEuro] =
  joinedData.map(r =>
    JoinedDataWithEuro(
      r.date,
      r.personId,
      r.firstName.toUpperCase(), // modified column
      r.lastName.toLowerCase(), // modified column
      initials(r.firstName, r.lastName), // new column
      r.customerName.trim(), // modified column
      r.amountDollars,
      dollarToEuro(r.amountDollars) // new column
    )
  )
Results:
+----------+--------+---------+--------+--------+----------------+-------------+------------------+
| date|personId|firstName|lastName|initials| customerName|amountDollars| amountEuros|
+----------+--------+---------+--------+--------+----------------+-------------+------------------+
|2020-01-01| 1| ERIC| tome| ET| Third Bank| 100.29| 119.3451|
|2020-07-01| 1| ERIC| tome| ET| PaperCo| 84990.15|101138.27849999999|
|2020-10-01| 1| ERIC| tome| ET|Disco Balls'r'us| 504.0| 599.76|
|2020-04-01| 3| CARA| rae| CR| Pet's Paradise| 1233451.33| 1467807.0827|
|2020-04-01| 2| JENNIFER| c| JC| Small Shoes| 4543.35| 5406.5865|
|2020-10-01| 2| JENNIFER| c| JC| Big Shovels| 9.99| 11.8881|
+----------+--------+---------+--------+--------+----------------+-------------+------------------+
Aggregating
Aggregation is done by using the DataFrame API, but we move back to a strongly typed Dataset after aggregation. There are a variety of functions that can be used in an aggregation: avg, sum, count, etc. The example below aggregates our sales data by user, summing and calculating the mean of the sales.
final case class TotalSalesByPerson(
  personId: Int,
  firstName: String,
  lastName: String,
  initials: String,
  sumAmountDollars: Double,
  sumAmountEuros: Double,
  avgAmountDollars: Double,
  avgAmountEuros: Double)

val totalSalesByPerson: Dataset[TotalSalesByPerson] =
  toDS {
    joinedDataWithEuro
      .groupBy($"personId", $"firstName", $"lastName", $"initials")
      .agg(
        sum($"amountDollars").alias("sumAmountDollars"),
        sum($"amountEuros").alias("sumAmountEuros"),
        avg($"amountDollars").alias("avgAmountDollars"),
        avg($"amountEuros").alias("avgAmountEuros")
      )
  }
Results:
+--------+---------+--------+--------+-----------------+------------------+------------------+------------------+
|personId|firstName|lastName|initials| sumAmountDollars| sumAmountEuros| avgAmountDollars| avgAmountEuros|
+--------+---------+--------+--------+-----------------+------------------+------------------+------------------+
| 2| JENNIFER| c| JC| 4553.34|5418.4746000000005| 2276.67|2709.2373000000002|
| 3| CARA| rae| CR| 1233451.33| 1467807.0827| 1233451.33| 1467807.0827|
| 1| ERIC| tome| ET|85594.43999999999|101857.38359999999|28531.479999999996| 33952.4612|
+--------+---------+--------+--------+-----------------+------------------+------------------+------------------+
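If you are following along in a single file, remember that everything above lives inside the FunctionalSpark object, so it still needs its closing brace. You may also want to display the final result and stop the session (the show and stop calls here are our addition):
totalSalesByPerson.show()

spark.stop()
}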
Conclusion
This was a brief introduction to how to process data in Spark using Scala. We focused on functional implementations of transformations using Datasets. If you have any questions, feel free to drop a comment!