Hey SparkSQL, What’s the Average Date?
Please read my recent blog posts at www.petersmith.net.
Apache Spark is one of the leading open source analytics frameworks, but it can’t do everything. In this blog post, we’ll look at a few different approaches to computing the average of date-typed data, which isn’t natively supported in Spark (as of version 2.3.0). Luckily though, Spark is highly customizable, allowing new analytic functions to be added quite easily.
You may be asking why you’d need to find the average of date values, especially since an average (or arithmetic mean) is typically an operation on numbers. For example, the average of January 7th, 1987, June 23rd, 1994, and December 10th, 1992, gives you the central date value of June 24th, 1991, but how is that useful?
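To make that arithmetic concrete, the central value above can be reproduced in plain Scala (using java.time, outside of Spark) by converting each date to a day count, averaging the counts, and converting back:

```scala
import java.time.LocalDate

// Convert each date to its epoch-day count (days since 1970-01-01)
val dates = Seq(
  LocalDate.of(1987, 1, 7),
  LocalDate.of(1994, 6, 23),
  LocalDate.of(1992, 12, 10)
)
val avgEpochDay = dates.map(_.toEpochDay).sum / dates.size

// Convert the averaged day count back to a date
println(LocalDate.ofEpochDay(avgEpochDay))  // 1991-06-24
```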
One such case is anomaly detection, where it’s useful to identify date values falling outside the range of what’s considered “normal”. That is, if we determine the average of a series of dates, as well as the standard deviation of those dates, we can then identify the outliers that fall beyond two standard deviations from the average.
Identifying the outliers gives us an opportunity to remove invalid data from our data set, or perhaps to start investigating the root cause of the anomaly. In this scenario, computing the average date from a large data set is important.
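As a sketch of that idea (in plain Scala, not part of the Spark code in this post), flagging dates beyond two standard deviations might look like this:

```scala
import java.time.LocalDate

// A tight cluster of 2017 dates, plus one anomalous value
val dates = Seq(
  "2017-01-02", "2017-02-14", "2017-03-04", "2017-05-06",
  "2017-06-20", "2017-08-01", "2017-09-15", "2017-10-12",
  "1950-06-15"
).map(LocalDate.parse)

// Work on the numeric day counts (days since 1970-01-01)
val days = dates.map(_.toEpochDay.toDouble)
val mean = days.sum / days.size
val stdDev = math.sqrt(days.map(d => math.pow(d - mean, 2)).sum / days.size)

// Anything beyond two standard deviations is flagged as an outlier;
// here, only the 1950 date is flagged
val outliers = dates.filter(d => math.abs(d.toEpochDay - mean) > 2 * stdDev)
```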
In this blog post, we’ll look at two different approaches to computing an average date using the Spark framework. We’ll also discuss some accuracy and performance implications.
The Problem
If Spark (particularly SparkSQL) already supported this functionality, we’d be able to compute the average of a series of dates by reading them into a Spark DataFrame (a table with rows and columns), then invoking the avg function on the appropriate column.
In this example, we’ll read the dates from the single-column CSV file dates.csv:
2017-01-02
2017-03-04
2017-05-06
2017-08-01
...
2017-10-12
To read this file into a Spark DataFrame, ensuring that the single column of data is interpreted as a Date-typed value, we use the following code:
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

// The single column ("datecol") must be interpreted as a Date
val schema = StructType(
  Seq(
    StructField("datecol", DateType)
  )
)

// Define a new DataFrame, based off the content of the CSV file
val df = spark.read.schema(schema).csv("dates.csv")

// Compute the average of the datecol column
df.agg(avg('datecol))
Unfortunately, this simple solution fails with the following error message:
org.apache.spark.sql.AnalysisException: cannot resolve 'avg(`datecol`)' due to data type mismatch: function average requires numeric types, not DateType;
It’s clear that Spark’s built-in avg function isn’t designed to support DateType columns, so a workaround is required.
Approach 1 — Convert to Number, and Back Again
Given that the avg function is intended to operate on numeric values, our first approach is to translate all the Date values into corresponding Int values, representing the number of days since a fixed point in time. We then perform the avg operation, and convert the result back to a Date value.
Unix-based systems use January 1st, 1970 as the point at which everything started (known as the “Epoch”), so we’ll do the same by converting all Date values to the number of days since the Epoch. For example, January 2nd, 2017 equates to 17168, and March 4th, 1812 equates to -57646.
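These example day counts can be double-checked in plain Scala (outside Spark) via java.time:

```scala
import java.time.LocalDate

// Days since the Unix Epoch (1970-01-01); dates before it are negative
println(LocalDate.of(2017, 1, 2).toEpochDay)  // 17168
println(LocalDate.of(1812, 3, 4).toEpochDay)  // -57646
```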
This numeric conversion is possible using SparkSQL’s datediff function.
import java.sql.Date

// Dates are first converted to number of days since this date
val baseDate = lit(Date.valueOf("1970-01-01"))

// Compute a DataFrame containing the average number of days
val avgDayDataFrame = df.agg(
  avg(
    datediff('datecol, baseDate)
  )
)
This approach works well, but gives back a numeric value (a Double), rather than a Date.
scala> avgDayDataFrame.show
+-----------------------------------------+
|avg(datediff(datecol, DATE '1970-01-01'))|
+-----------------------------------------+
|                                  17303.8|
+-----------------------------------------+
As it turns out, there’s no native SparkSQL function that does the opposite of datediff, to obtain a Date value from our numeric average. The date_add function looked like it might work, but it requires a constant Int number of days, rather than taking the Int result from a DataFrame.
Our solution is to extract the average value from the first column of the first row of the DataFrame, and then explicitly create a new Date object (which expects the number of milliseconds since the Epoch). Note that we use Long arithmetic here, to avoid exceeding the 2³¹ limit of Int values.
// Trigger query (by running collect), and extract a native Long
// from the 0th column and 0th row of the DataFrame.
val avgDay = avgDayDataFrame.collect()(0).getDouble(0).toLong
// Now convert the number of days back to a Date type. The Date
// constructor requires the number of milliseconds since 1970-01-01.
val avgDate = new Date(avgDay * 24 * 60 * 60 * 1000)
This works, but isn’t very elegant, particularly since the final conversion to Date is done outside the context of Spark DataFrames. We therefore can’t do additional DataFrame processing in the same Spark query. The solution is to encapsulate those last few lines in a Spark UDF (User Defined Function).
import org.apache.spark.sql.expressions.UserDefinedFunction

// define a function that takes an Int, and returns the Date.
// Note the toLong, so the multiplication uses Long arithmetic
// rather than overflowing Int.
val daysToDate: Int => Date = { days =>
  new Date(days.toLong * 24 * 60 * 60 * 1000)
}

// convert this to a UDF-based function that uses Spark's Column
// data type as the input and output.
val daysToDateUDF: UserDefinedFunction = udf(daysToDate)
A Spark UDF is essentially a function that accepts a Spark SQL Column-typed value as input, and returns a Column-typed value as output. This allows the function to be used entirely within a Spark query.
The new Spark query, returning a Date value, is now:
val avgDayDataFrame = df.agg(
daysToDateUDF(
avg(
datediff('datecol, baseDate)
)
)
)
val avgDay = avgDayDataFrame.collect()(0).getDate(0)
This gives us exactly what we need. Note that Spark UDFs are often reported to be inefficient, particularly because the Spark SQL optimizer is unable to understand them, and therefore unable to optimize them. In our case, we’re only running the UDF once per DataFrame (not once per row), so the performance impact should be minimal.
Approach 2 — User Defined Aggregate Functions
An alternative approach is to define a User Defined Aggregation Function (UDAF). Whereas a regular UDF acts on a single table cell, a UDAF operates on a full column to produce a single aggregated value.
Here’s how an avgdate function would be used in a Spark query:
val avgdate = new AvgDateUDF

val avgDayDataFrame = df.agg(avgdate('datecol))
val avgDay = avgDayDataFrame.collect()(0).getDate(0)
This syntax is much more readable than the previous example, given that you’re calling the avgdate function on a Date column, and getting back a resulting Date value.
To define a Spark UDAF, we must extend the UserDefinedAggregateFunction class and override the class members. The key members to be overridden are:
- inputSchema — Defines the type of values the UDAF can operate on (that is, Date values).
- bufferSchema — Defines the intermediate counters used during the aggregation. In this example, we track the count of the number of date values we’ve seen, as well as the running total of the dates.
- dataType — Defines the type of the output data, in this case DateType.
- initialize() — A method setting the counters to their initial values.
- update() — A method called to add each new Date value to our intermediate counter values. Note the special handling for null field values.
- merge() — Given that Spark is a distributed analytics framework, this method joins together the counters from different Spark partitions that were potentially executed on different compute nodes.
- evaluate() — Converts the intermediate counters into a final Date value. This is done by simply dividing the total by the count, and then converting to a Date type.
import java.sql.Date
import java.time.LocalDate
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class AvgDateUDF extends UserDefinedAggregateFunction {

  // each value being aggregated has this type
  override def inputSchema: StructType =
    StructType(StructField("dateValue", DateType) :: Nil)

  // intermediate values used during aggregation
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("total", LongType) :: Nil
  )

  // output type of the aggregation
  override def dataType: DataType = DateType

  // This aggregation always returns a consistent output,
  // given a consistent input
  override def deterministic: Boolean = true

  // Initialize our internal counters.
  override
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Update our counters with a new data value.
  override
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val thisDate = input.getAs[Date](0)
    if (thisDate != null) {
      buffer(0) = buffer.getAs[Long](0) + 1
      buffer(1) = buffer.getAs[Long](1) +
        thisDate.toLocalDate.toEpochDay
    }
  }

  // merge counters from two different Spark partitions
  override
  def merge(buff1: MutableAggregationBuffer, buff2: Row): Unit = {
    buff1(0) = buff1.getAs[Long](0) + buff2.getAs[Long](0)
    buff1(1) = buff1.getAs[Long](1) + buff2.getAs[Long](1)
  }

  // Return the final value, as a Date
  override def evaluate(buffer: Row): Any = {
    val avgDays = buffer.getAs[Long](1) / buffer.getAs[Long](0)
    Date.valueOf(LocalDate.ofEpochDay(avgDays))
  }
}
Note that for performance reasons, we access the intermediate counter variables as buffer(0) and buffer(1), rather than using their symbolic "count" and "total" names.
Arithmetic Overflow
Even if you’re not a Scala expert, you can hopefully get the gist of the previous code. That is, initialize a counter to 0, and a sum to 0, and then for every new date value, add the number of days (since the base date) to the sum, and increment the counter. Finally, divide the total sum by the count of items seen.
One limitation of this approach is Arithmetic Overflow. That is, the total variable has type Long, implying it has a maximum value of 2⁶³-1 (or 9,223,372,036,854,775,807). That’s a pretty large number, but it’s still possible to overflow that data type and have it wrap around to a negative value. If that were to happen, we’d get a totally incorrect result.
In reality though, this is unlikely to happen with the Long data type (it would definitely be a problem with Int). Given that we’ll likely be dealing with recent dates (that is, near to the year 2018), most of the numbers we add will be around 17,000 (days since 1970). We’d therefore need to find the average of over 500 trillion date values before overflow would happen. It’s probably not worth worrying about this case.
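A quick back-of-the-envelope check in plain Scala (using 17,000 as a representative round number) confirms the headroom:

```scala
// Roughly 17,000 days between the Epoch and a 2018-era date
val daysPerValue = 17000L

// How many such values can be summed before a Long overflows?
val maxValues = Long.MaxValue / daysPerValue
println(maxValues)  // 542551296285575 -- over 500 trillion values
```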
In fact, looking at Spark’s avg function, it uses either the Double data type, which can represent values up to around 10³⁰⁸, or the BigDecimal data type, which can be arbitrarily large (limited only by available RAM). Clearly this is not a problem for most Spark users.
If we wanted to be really paranoid, there are ways to avoid overflow, such as dividing the data set into equal-sized subsets and then averaging the averages, or perhaps iteratively refining the average. We’ll leave those solutions for another day.
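As a taste of the first idea, the mean of per-chunk means equals the overall mean whenever the chunks are the same size, and each partial sum stays far smaller. A minimal plain-Scala sketch, assuming an evenly divisible data set:

```scala
// Epoch-day values, split into equal-sized chunks of two
val days = Seq(17168L, 17229L, 17292L, 17379L, 17450L, 17500L)
val chunks = days.grouped(2).toSeq

// Average each chunk, then average the (much smaller) chunk averages
val chunkAvgs = chunks.map(c => c.sum.toDouble / c.size)
val overallAvg = chunkAvgs.sum / chunkAvgs.size

// Matches the direct mean of all six values
println(overallAvg == days.sum.toDouble / days.size)  // true
```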
Performance
Finally, let’s get a rough indication of the performance of these different approaches. It’s interesting to measure performance, since User Defined Functions are reportedly slower than native Spark functions, which are handled better by Spark’s Catalyst Optimizer. Although our avgdate function is easier to use in queries, it might just be slower.
In our tests, we used Amazon EMR-5.13.0 (with Spark 2.3.0) with one master and two core nodes of type m4.2xlarge (8 CPU and 32GB of RAM). The input data set was 10M rows of randomly-generated data in CSV format, with each row having 100 columns. The test case involved computing the average date for a particular date-typed column.
For each test case, the result was computed six times, with the data from the first test run being discarded (to ignore the impact of cold caches). The reported result is the average duration (in milliseconds) of the remaining five test runs.
- Base Case (675ms — StdDev 41ms) — The standard Spark min function was used as an indication of how fast native Spark functions could read data, therefore defining a base case scenario to compare against.
- Approach 1 (947ms — StdDev 13ms) — Our first approach of computing the avg of datediff.
- Approach 2 (1390ms — StdDev 34ms) — Our second approach, using the avgdate user defined aggregation function.
To eliminate the cost of reading the CSV file into memory (the source data resides in Amazon S3), the .cache() directive was used to pin the data into RAM. Therefore, the duration measurements are purely the time required to scan the date column and perform the averaging operation.
As you can see, our second approach takes almost 50% longer to compute the average date, compared to the first approach. It also takes twice as long as our base case of computing the minimum date value. Although some amount of code optimization is likely possible, our user-defined aggregation function is clearly not the best approach, even though it makes the code more readable.
Conclusion
Apache Spark is an excellent general-purpose framework for performing data analytics. Even though it comes with a variety of built-in analytic functions, it’s sometimes necessary to implement your own functions. Spark SQL provides a convenient mechanism for defining both cell-based UDFs and column-based UDAFs, making Spark queries easier to construct. Initial tests indicate that user-defined functions are less performant than native Spark functions.