User Defined Aggregate Functions (UDAFs) in Apache Spark
Introduction
User Defined Aggregate Functions, also known as UDAFs, are user-defined functions that act on multiple rows at a time. As the name suggests, UDAFs are used in aggregation scenarios to produce one row for each set of data they operate on. These sets of data are most commonly the result of grouping operations, i.e. groupBy, and of partitionBy in window functions. The purpose of UDAFs is similar to that of User Defined Functions (UDFs): to allow the user to implement custom functionality that doesn’t come out of the box with Spark.
The official documentation describes UDAFs as:
User-Defined Aggregate Functions (UDAFs) are user-programmable routines that act on multiple rows at once and return a single aggregated value as a result.
Link to documentation: https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html
Code Repository
If you want to follow along, or review the final code, it can be obtained from the following GitHub repository:
Project Structure
The code used in this article is in Scala, and Maven is used to manage dependencies. At the time of writing, PySpark doesn’t support UDAFs directly, but the process of using a UDAF defined in Scala from PySpark is covered later in this article. Additionally, the Appendix contains the Maven POM file. The project should be executed with JDK 1.8.
Defining UDAFs
The following steps need to be taken to define a Spark UDAF:
- Define the UDAF functionality by extending the Aggregator class (org.apache.spark.sql.expressions.Aggregator).
- Register the UDAF as an untyped UDAF by calling the spark.udf.register function so that it can be used with untyped DataFrames. Note that this step is not needed if the aggregator is to be used with a custom typed Dataset class.
- Use the UDAF in Spark SQL clauses.
Note: The older org.apache.spark.sql.expressions.UserDefinedAggregateFunction class for defining UDAFs has been deprecated, and the newer Aggregator class should be used instead.
Example UDAF implementation: Calculating sum
Let’s start with a simple example and define a UDAF that returns the sum of numbers for each group. This implementation is equivalent to the sum function found in Spark SQL. As mentioned before, the first step is to extend the Aggregator class to create an object called SumUdaf:
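A minimal sketch of such an object, assuming Spark 3.x on the classpath, could look like the following. The parameter names and the exact println messages are illustrative assumptions rather than the article's listing.

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Sums Int inputs; the buffer and the output are both Int.
object SumUdaf extends Aggregator[Int, Int, Int] {

  // Initial value of every intermediate buffer.
  override def zero: Int = {
    println("zero called")
    0
  }

  // Folds one input value of the group into a buffer.
  override def reduce(buffer: Int, value: Int): Int = {
    println(s"reduce called: buffer=$buffer, value=$value")
    buffer + value
  }

  // Combines two partial sums.
  override def merge(b1: Int, b2: Int): Int = {
    println(s"merge called: b1=$b1, b2=$b2")
    b1 + b2
  }

  // Converts the final buffer into the result; for a sum it is returned as-is.
  override def finish(reduction: Int): Int = {
    println(s"finish called: reduction=$reduction")
    reduction
  }

  override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
```

Because the methods are plain Scala functions, they can also be called directly (without a Spark Session) to check the logic.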
The implementation extends the Aggregator[-IN, BUF, OUT] class. The three type parameters of the Aggregator class are:
- IN: The type of the input data. This is set to Int because the UDAF will take and sum integers.
- BUF: The type used by the internal buffers that are computed during the reduce calls.
- OUT: The output type, i.e. the type of the result that will be returned by the UDAF.
Next, the following methods need to be overridden when implementing the UDAF:
- zero: This function is called to initialize the internal buffers that are filled up to calculate the final value. For the sum operation, it sets the initial value to 0.
- reduce: This function is called to calculate intermediate values, and it is called once for each value in the group. The official documentation recommends modifying the buffer object and returning it, instead of creating a new object, for better performance.
- merge: This function is called when two buffers containing intermediate values are merged. Because each buffer for the sum operation contains an intermediate sum, we simply add the two intermediate sums here.
- finish: This function is called at the end, once for each group. It converts the final state of the buffer into the final output. For the sum operation, the buffer already holds the final sum, so it is simply returned.
- bufferEncoder: This specifies the encoder for the intermediate buffer. In our example, it is Encoders.scalaInt.
- outputEncoder: This specifies the encoder for the output. In our example, it is Encoders.scalaInt.
Please also review the official documentation of Spark UDAF and see the code examples.
Additionally, println statements have been added to each function of the UDAF to take a peek into its internal values for better understanding.
Note: The print statements appear on the console while running the code examples because, in this article, the code is executed on Spark in local mode, where all the data and execution live in the Spark driver. If this code is executed on a Spark cluster, the output of the print statements will not appear in the driver logs, and it will have to be accessed by looking into the stdout logs of each executor on each node. Spark provides ways to do this, but this is out of scope for this article.
Let’s look at the data that this UDAF will run on. The following code creates the Spark Session, registers the UDAF, and then calls it:
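A minimal sketch of such a driver, assuming Spark 3.x in local mode, might look like this. The object name SumUdafDemo, the helper sumsByMod and the column alias total are hypothetical, and a compact copy of the aggregator is inlined so the block stands on its own.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, expr, udaf}

object SumUdafDemo {

  // Compact stand-in for the SumUdaf object (without the println tracing).
  object SumUdaf extends Aggregator[Int, Int, Int] {
    override def zero: Int = 0
    override def reduce(buffer: Int, value: Int): Int = buffer + value
    override def merge(b1: Int, b2: Int): Int = b1 + b2
    override def finish(reduction: Int): Int = reduction
    override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
    override def outputEncoder: Encoder[Int] = Encoders.scalaInt
  }

  // Builds the test data, registers the UDAF and returns mod -> sum.
  def sumsByMod(spark: SparkSession): Map[Int, Int] = {
    import spark.implicits._

    // number: 1 to 10; mod: 1 for odd numbers, 0 for even numbers.
    val df = (1 to 10).toDF("number").withColumn("mod", col("number") % 2)

    // Wrap the typed Aggregator so it can be called by name on untyped DataFrames.
    spark.udf.register("SumUdaf", udaf(SumUdaf))

    df.groupBy("mod")
      .agg(expr("SumUdaf(number)").as("total"))
      .collect()
      .map(row => row.getInt(0) -> row.getInt(1))
      .toMap
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("SumUdafDemo").master("local[*]").getOrCreate()
    sumsByMod(spark).toSeq.sorted.foreach { case (mod, total) => println(s"mod=$mod sum=$total") }
    spark.stop()
  }
}
```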
This is the main function that creates the test data, groups it, and executes the SumUdaf function that we just defined. Also, note that the UDAF has been registered in Spark SQL, and therefore it has to be called inside an expr function as a string. The call is: .agg(expr("SumUdaf(number)")).
The test data is a simple DataFrame with two columns: the number column contains numbers from 1 to 10, and the mod column contains 1 for odd numbers and 0 for even numbers.
The DataFrame is aggregated on the mod column, which creates two groups, one containing the even numbers and one containing the odd numbers:
These values are the ones that the UDAF will operate on:
- Buffers will be initialized to 0 by calls to the zero function.
- Then, the distributed operation will begin, and the reduce and merge functions will be called:
  - Each call to the reduce function will add one number of the group to an intermediate buffer. There will be one reduce call for each value in each group.
  - The merge function will be called to merge intermediate values, which in this case is a simple addition.
- Finally, the finish function will be called, once for each group.
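The sequence above can be imitated in plain Scala, without Spark, to make the order of calls concrete. This is only an illustration: the split of each group into two partitions is an assumption, since Spark decides how the rows are actually distributed.

```scala
object AggregationLifecycle {
  // The four steps of a sum aggregation, written as plain functions.
  def zero: Int = 0
  def reduce(buffer: Int, value: Int): Int = buffer + value
  def merge(b1: Int, b2: Int): Int = b1 + b2
  def finish(buffer: Int): Int = buffer

  // Simulates one group whose values are spread over several partitions.
  def aggregate(partitions: Seq[Seq[Int]]): Int = {
    // One buffer per partition, starting at zero and filled by reduce.
    val partials = partitions.map(_.foldLeft(zero)(reduce))
    // Partial buffers are merged, then finish produces the result.
    finish(partials.foldLeft(zero)(merge))
  }

  def main(args: Array[String]): Unit = {
    val evens = Seq(Seq(2, 4), Seq(6, 8, 10)) // hypothetical partitioning
    val odds = Seq(Seq(1, 3, 5), Seq(7, 9))   // hypothetical partitioning
    println(s"mod=0 -> ${aggregate(evens)}")
    println(s"mod=1 -> ${aggregate(odds)}")
  }
}
```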
The following should be the final output of the job, with a sum of 30 for the even group (mod = 0) and 25 for the odd group (mod = 1):
Let’s execute the code. It produces the following output:
Here, in addition to the final output, we can see the merge and reduce calls, as well as the finish calls.
Second example: Calculating average
For better understanding, let’s take a look at another example. Calculating the average is similar to calculating the sum, but with one additional requirement: the number of values has to be kept in addition to the running sum. In this implementation, both the running sum and the running count are kept in an intermediate case class, and the finish function divides the sum by the count. Changing the types of the functions appropriately, the following UDAF computes the average:
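A sketch of what such an aggregator could look like is shown below, assuming Spark 3.x. The names AvgBuffer and AvgUdaf and the println messages are assumptions chosen for illustration.

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Intermediate state: running sum and running count.
// The fields are vars so that a buffer can be updated in place.
case class AvgBuffer(var sum: Long, var count: Long)

object AvgUdaf extends Aggregator[Int, AvgBuffer, Double] {

  // A fresh buffer for every call, so buffers are never shared.
  override def zero: AvgBuffer = AvgBuffer(0L, 0L)

  // Adds one value to the running sum and increments the count.
  override def reduce(buffer: AvgBuffer, value: Int): AvgBuffer = {
    println(s"reduce called: buffer=$buffer, value=$value")
    buffer.sum += value
    buffer.count += 1
    buffer
  }

  // Folds the second buffer into the first one.
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    println(s"merge called: b1=$b1, b2=$b2")
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }

  // Divides the sum by the count to obtain the average.
  override def finish(reduction: AvgBuffer): Double = {
    println(s"finish called: reduction=$reduction")
    reduction.sum.toDouble / reduction.count
  }

  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
```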
This example is similar to the one in the official Spark UDAF documentation, except for the additional print statements that allow us to look at the intermediate values and the parameters passed to each function call. Registering and running this UDAF produces the following output:
Calling UDAF in window function
UDAFs work on sets or groups of data, and therefore, just as with groupBy, they can also be called on windows. The following example adds the average of each window as a column:
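One way this could be written, assuming Spark 3.x and a DataFrame with the number and mod columns used earlier, is sketched below. The object name WindowUdafDemo, the helper withGroupAvg and the tuple-based stand-in aggregator are assumptions made to keep the block self-contained.

```scala
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions.{expr, udaf}

object WindowUdafDemo {

  // Hypothetical stand-in for the averaging UDAF; the buffer is (sum, count).
  object AvgUdaf extends Aggregator[Int, (Long, Long), Double] {
    override def zero: (Long, Long) = (0L, 0L)
    override def reduce(b: (Long, Long), value: Int): (Long, Long) =
      (b._1 + value, b._2 + 1)
    override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
      (b1._1 + b2._1, b1._2 + b2._2)
    override def finish(b: (Long, Long)): Double = b._1.toDouble / b._2
    override def bufferEncoder: Encoder[(Long, Long)] =
      Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

  // Adds a group_avg column holding the average of each row's window.
  def withGroupAvg(spark: SparkSession, df: DataFrame): DataFrame = {
    spark.udf.register("AvgUdaf", udaf(AvgUdaf))
    val byMod = Window.partitionBy("mod") // one window per value of mod
    df.withColumn("group_avg", expr("AvgUdaf(number)").over(byMod))
  }
}
```

Unlike groupBy, the window version keeps every input row and repeats the group's average on each of them.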
The above function gives the following output:
The added column group_avg contains the average of each window, and it works the same way as it does in the previous groupBy example. Additionally, the aggregate is computed once for each group rather than once for each row, even though the column has a value for each row in the DataFrame.
Advanced UDAF examples: Multiple columns, Maps and Structs
Next, for a more complex example, consider a dataset of readings from sensors that measure multiple environmental parameters. In our example, the readings are reported from two locations:
The requirement is to keep the latest values of the readings at each location. Additionally, not all sensors send readings at the same frequency. Humidity appears once every two readings, and light intensity is slower still, once every four readings. So, taking the latest reading is more involved than just taking the latest value by timestamp.
In order to do this, let’s write a UDAF that takes two columns, timestamp and readings. It maintains a running record as it iterates through the values in the reduce function, and updates that record based on the timestamp to keep the last occurrence of the reading from each sensor.
The following code accomplishes this:
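A condensed sketch of how such a UDAF could be structured is given below, assuming Spark 3.x. The case class and object names follow the ones mentioned in this article, but the private helper update, its tie-breaking rule (an equal timestamp keeps the existing value) and the omission of the println tracing are assumptions of this sketch.

```scala
import java.sql.Timestamp
import scala.collection.mutable
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

// Input: the two columns passed to the UDAF.
case class DataColumns(timestamp: Timestamp, readings: Map[String, Double])

// Buffer entry: the latest value seen for one reading and when it was seen.
case class ValueWithTime(var timestamp: Timestamp, var value: Double)

object LatestValuesOfKeysMapUdaf
    extends Aggregator[DataColumns, mutable.Map[String, ValueWithTime], Map[String, Double]] {

  override def zero: mutable.Map[String, ValueWithTime] =
    mutable.Map.empty[String, ValueWithTime]

  // Stores the value only if it is newer than the one already in the buffer.
  private def update(buffer: mutable.Map[String, ValueWithTime],
                     key: String, ts: Timestamp, value: Double): Unit =
    buffer.get(key) match {
      case Some(existing) if ts.after(existing.timestamp) =>
        existing.timestamp = ts // update in place instead of allocating
        existing.value = value
      case Some(_) => () // existing entry is at least as recent: keep it
      case None => buffer(key) = ValueWithTime(ts, value)
    }

  override def reduce(buffer: mutable.Map[String, ValueWithTime],
                      row: DataColumns): mutable.Map[String, ValueWithTime] = {
    row.readings.foreach { case (key, value) => update(buffer, key, row.timestamp, value) }
    buffer
  }

  override def merge(b1: mutable.Map[String, ValueWithTime],
                     b2: mutable.Map[String, ValueWithTime]): mutable.Map[String, ValueWithTime] = {
    b2.foreach { case (key, v) => update(b1, key, v.timestamp, v.value) }
    b1
  }

  // Drops the timestamps and returns an immutable map of the latest values.
  override def finish(reduction: mutable.Map[String, ValueWithTime]): Map[String, Double] =
    reduction.map { case (key, v) => key -> v.value }.toMap

  override def bufferEncoder: Encoder[mutable.Map[String, ValueWithTime]] =
    ExpressionEncoder[mutable.Map[String, ValueWithTime]]()
  override def outputEncoder: Encoder[Map[String, Double]] =
    ExpressionEncoder[Map[String, Double]]()
}
```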
In order to take multiple inputs, a case class can be used to define the columns required by the UDAF. The case class DataColumns is defined for this purpose, and it is also the input value type (IN). The case class has two fields:
- timestamp (java.sql.Timestamp): Represents the timestamp column.
- readings (Map[String, Double]): Represents the readings. Note that a Map type in Spark can have only one type of key and one type of value, and therefore we’ll have to treat integer values as Double as well.
The data is internally passed between steps as a mutable Map whose values are instances of another case class: scala.collection.mutable.Map[String, ValueWithTime], where the case class is defined as ValueWithTime(var timestamp: Timestamp, var value: Double). This map contains the latest value of each type of reading (temperature, air_quality, humidity, and light_intensity), and as new values come in, the intermediate map is updated with the newer value of each reading based on the timestamp. The timestamp is stored as well for comparison. Please read the reduce and merge functions for this.
The Map is mutable, and the ValueWithTime fields are also mutable (var), because it is recommended to update existing structures in the UDAF and to avoid creating new instances, in order to conserve memory and time. In the reduce and merge functions, new instances are not created; instead, one of the existing instances is updated. Please refer to the code examples in the official Spark documentation for more details: https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html
Also note that println statements have been added to look into the execution for better understanding.
Finally, in the finish function, the mutable map is converted to an immutable Map and returned as a MapType column. The org.apache.spark.sql.catalyst.encoders.ExpressionEncoder is used to encode intermediate and final values. Reference link:
The code is executed in the following driver function:
The code produces the following final result:
The final result contains the latest occurring values from each sensor for each location.
The following is some of the output from the print statements in the reduce and merge functions:
Also, please note that to make things execute in order for easier understanding of each stage, the code has been executed with a single thread:
Next, let’s see what changes if the sensor data is present as a Struct instead of a Map:
In this case, the input type of the data changes. We pass the two columns as an org.apache.spark.sql.Row, because a Struct is received in UDFs and UDAFs as a Row instance. Because two fields are passed to the UDAF, the input Row instance contains both the timestamp field and the readings field, the latter as a Row instance nested within the outer Row. The following UDAF produces the required result:
In this UDAF, the intermediate and output types are the same as in the previous example. Only the reduce function implementation is different, because it has to extract the two columns from the incoming Row instance before processing them. In addition, when this UDAF is registered, the input encoder has to be passed as well, and it is built from the schema of the two columns:
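The extraction step and the input schema could be sketched roughly as follows. The object name StructInput, the helper extract and the assumption that every struct field is a Double are illustrative only, and the registration is left as a comment because the way to obtain an Encoder[Row] from a schema differs between Spark versions.

```scala
import java.sql.Timestamp
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType, TimestampType}

object StructInput {

  // Assumed schema of the readings struct (all fields treated as Double).
  val readingsSchema: StructType = StructType(Seq(
    StructField("temperature", DoubleType),
    StructField("air_quality", DoubleType),
    StructField("humidity", DoubleType),
    StructField("light_intensity", DoubleType)))

  // Schema of the two columns handed to the UDAF: timestamp and readings.
  val inputSchema: StructType = StructType(Seq(
    StructField("timestamp", TimestampType),
    StructField("readings", readingsSchema)))

  // Pulls the timestamp and the non-null struct fields out of the incoming Row.
  def extract(input: Row): (Timestamp, Map[String, Double]) = {
    val timestamp = input.getTimestamp(0)
    val readings = input.getStruct(1) // the nested Row
    val values = readings.schema.fieldNames.zipWithIndex.collect {
      case (name, i) if !readings.isNullAt(i) => name -> readings.getDouble(i)
    }.toMap
    (timestamp, values)
  }

  // Registration (assumption, not verified against a specific Spark version):
  //   spark.udf.register("<name>", functions.udaf(<aggregator>, <Encoder[Row] for inputSchema>))
  // where the encoder is built from inputSchema with the row encoder of the Spark version in use.
}
```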
The following main function can be used to execute the UDAF:
And it produces the following output:
Please note that these examples have been implemented as UDAFs for demonstration and understanding. It is also possible to obtain the same results with built-in Spark SQL functions and aggregate functions.
Using the UDAF in PySpark jobs
UDAFs written in Scala can be called in PySpark jobs by following these steps:
- Add a register function to the UDAF, which will be called from PySpark to register the UDAF.
- Compile the UDAF written in Scala along with its dependencies into a JAR file. Because the example code uses Maven to manage dependencies and the build process, the Maven Assembly Plugin is used for this.
- Include the JAR file in the spark-submit call using the --jars argument, or while creating the Spark Session.
- Call the register function from PySpark to register the UDAF and make it available for use.
- Use the UDAF.
Note: The spark.udf.registerJavaFunction function doesn’t work for UDAFs derived from the Aggregator class, and the code fails with the exception pyspark.sql.utils.AnalysisException: UDF class com.tutorial.LatestValuesOfKeysMapUdaf doesn't implement any UDF interface. I tried to find a way to register UDAFs derived from the Aggregator class, and finally found this Stack Overflow answer that worked correctly:
The register function in the UDAF implementation (LatestValuesOfKeysMapUdaf) is:
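The general shape of such a register function is sketched below on a small, hypothetical aggregator (LongSumUdaf) instead of the full LatestValuesOfKeysMapUdaf, assuming Spark 3.x. The SQL name passed to spark.udf.register and the PySpark call mentioned in the comment are assumptions.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Hypothetical aggregator, used only to show the shape of a register function.
object LongSumUdaf extends Aggregator[Long, Long, Long] {
  override def zero: Long = 0L
  override def reduce(buffer: Long, value: Long): Long = buffer + value
  override def merge(b1: Long, b2: Long): Long = b1 + b2
  override def finish(reduction: Long): Long = reduction
  override def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong

  // Registers this aggregator under a SQL name on the given session.
  // From PySpark, this method would be reached through the JVM gateway,
  // passing the Java SparkSession that backs the Python one (for example
  // spark._jvm.<package>.LongSumUdaf.register(spark._jsparkSession)).
  def register(spark: SparkSession): Unit = {
    spark.udf.register("LongSumUdaf", udaf(this))
    ()
  }
}
```

Keeping the registration on the Scala side means the typed Aggregator is wrapped by functions.udaf inside the JVM, so PySpark only needs to trigger the call and can then use the function by name in SQL expressions.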
Please refer to the Maven POM file in the Appendix to see how the JAR file is packaged. It can be packaged using the mvn package command, and the Maven Assembly Plugin creates the JAR with all dependencies:
In PySpark, the following code registers and then calls the UDAF:
And when executed, it produces the required result:
The code repository containing the PySpark project is:
This concludes this article on User Defined Aggregate Functions. Please share your thoughts and questions in comments.
Appendix
POM File
Following is the POM file:
Spark Session creation in Scala
Time zone is set to UTC.
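A minimal sketch of such a session factory, assuming local mode with a single thread as used for the traced runs in this article, might be the following. The object name SparkSessionFactory and the appName parameter are hypothetical.

```scala
import org.apache.spark.sql.SparkSession

object SparkSessionFactory {
  // Creates (or reuses) a local, single-threaded session with UTC as the session time zone.
  def create(appName: String): SparkSession =
    SparkSession.builder()
      .appName(appName)
      .master("local[1]")                          // one thread, so calls are printed in order
      .config("spark.sql.session.timeZone", "UTC") // timestamps are interpreted and shown in UTC
      .getOrCreate()
}
```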