Python User-Defined Aggregate Functions: now generally available

We are pleased to announce that Python User-Defined Aggregate Functions (UDAFs) are now generally available!


Overview

User-Defined Aggregate Functions (UDAFs) in Snowflake provide a framework for defining custom aggregations in Python.

UDAFs extend the functionality of standard aggregate functions by allowing you to define complex, domain-specific calculations that Snowflake’s built-in aggregate functions do not support. Whether you need to calculate a weighted average, perform advanced statistical analysis, or implement unique aggregation logic, UDAFs offer the flexibility and customization to meet your data processing requirements.
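The following sketch illustrates the weighted-average case with a handler class. It follows the same shape as the examples later in this post: a constructor, an `aggregate_state` property, and `accumulate`, `merge`, and `finish` methods. The class name `WeightedAvg`, the two-argument `(value, weight)` signature, and the list-based state are illustrative assumptions. The last few lines call the methods by hand to mimic two partial aggregations being merged.

```python
class WeightedAvg:
    """Handler for a hypothetical two-argument weighted_avg(value, weight) UDAF."""

    def __init__(self):
        # Aggregate state: [sum of value * weight, sum of weights].
        self._state = [0.0, 0.0]

    @property
    def aggregate_state(self):
        return self._state

    def accumulate(self, value, weight):
        # Defensive NULL handling: skip rows with a missing value or weight.
        if value is None or weight is None:
            return
        self._state[0] += value * weight
        self._state[1] += weight

    def merge(self, other_state):
        # Combine a partial state produced by another handler instance.
        self._state[0] += other_state[0]
        self._state[1] += other_state[1]

    def finish(self):
        weighted_sum, total_weight = self._state
        # Return None (SQL NULL) when no weight was accumulated.
        return weighted_sum / total_weight if total_weight else None


# Local check: two "partitions" of (price, quantity) rows, merged at the end.
left, right = WeightedAvg(), WeightedAvg()
for price, qty in [(10.0, 1), (20.0, 3)]:
    left.accumulate(price, qty)
right.accumulate(30.0, 4)
left.merge(right.aggregate_state)
print(left.finish())  # 23.75
```

The state holds two running totals rather than a running average. This keeps `merge` a simple element-wise addition, because averages from different partitions cannot be combined correctly without their weights.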

Because UDAFs run in Snowpark, you can import third-party Python libraries to help implement your aggregation logic. For example, Apache DataSketches is a popular library of stochastic streaming algorithms, which are designed to process data in a single pass. This property makes these algorithms well suited to streaming data where performance is a top priority. See Felipe’s blog post for a deep dive into Python UDAFs with Apache DataSketches.
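The sketch below does not use Apache DataSketches. It substitutes a simpler technique, a k-minimum-values (KMV) distinct-count estimator written in plain Python and shaped as a UDAF handler. The goal is to show why single-pass, mergeable algorithms fit the `accumulate`/`merge` model well. The state stays bounded at K hash values no matter how many rows arrive, and merging two partial states yields the same state as one pass over all rows. The class name, `K = 256`, and SHA-1-based hashing are illustrative assumptions.

```python
import hashlib
import heapq


def _unit_hash(value) -> float:
    """Map a value to a pseudo-random float in [0, 1), deterministic across runs."""
    digest = hashlib.sha1(repr(value).encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") / 2**64


class KmvDistinctCount:
    """Approximate COUNT(DISTINCT ...) with a k-minimum-values sketch (not DataSketches)."""

    K = 256  # sketch size: larger K lowers the error but grows the state

    def __init__(self):
        # Aggregate state: sorted list of the K smallest distinct hashes seen so far.
        self._state = []

    @property
    def aggregate_state(self):
        return self._state

    def _absorb(self, hashes):
        # Keep only the K smallest distinct hash values.
        self._state = heapq.nsmallest(self.K, set(self._state).union(hashes))

    def accumulate(self, input_value):
        if input_value is None:
            return
        h = _unit_hash(input_value)
        # A hash larger than the current K-th smallest can never enter the sketch.
        if len(self._state) < self.K or h < self._state[-1]:
            self._absorb([h])

    def merge(self, other_state):
        self._absorb(other_state)

    def finish(self):
        if len(self._state) < self.K:
            return len(self._state)  # fewer than K distinct values: the count is exact
        # Standard KMV estimator: (K - 1) / (K-th smallest hash).
        return int((self.K - 1) / self._state[-1])


# Local check: 20,000 rows with 10,000 distinct values, split into two partitions.
rows = list(range(10_000)) * 2
first, second = KmvDistinctCount(), KmvDistinctCount()
for v in rows[:7_000]:
    first.accumulate(v)
for v in rows[7_000:]:
    second.accumulate(v)
first.merge(second.aggregate_state)
print(first.finish())  # an estimate near 10,000, not an exact count
```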

Examples

Here’s a simple example that uses in-line SQL to create and call a custom aggregate function, with a Python handler that computes an average:

CREATE OR REPLACE AGGREGATE FUNCTION python_avg(a INT)
RETURNS FLOAT
LANGUAGE PYTHON
RUNTIME_VERSION = 3.8
HANDLER = 'PythonAvg'
AS $$
from dataclasses import dataclass

@dataclass
class AvgAggState:
    count: int
    sum: int

class PythonAvg:
    def __init__(self):
        # This aggregate state is an object data type.
        self._agg_state = AvgAggState(0, 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input_value):
        # Add one input row to the running sum and count.
        sum = self._agg_state.sum
        count = self._agg_state.count

        self._agg_state.sum = sum + input_value
        self._agg_state.count = count + 1

    def merge(self, other_agg_state):
        # Combine a partial state computed by another handler instance.
        sum = self._agg_state.sum
        count = self._agg_state.count

        other_sum = other_agg_state.sum
        other_count = other_agg_state.count

        self._agg_state.sum = sum + other_sum
        self._agg_state.count = count + other_count

    def finish(self):
        # Produce the final result; return NULL for empty input instead of dividing by zero.
        sum = self._agg_state.sum
        count = self._agg_state.count
        if count == 0:
            return None
        return sum / count
$$;
-- Example data
CREATE OR REPLACE TABLE sales(item STRING, price INT);
INSERT INTO sales VALUES ('car', 10000), ('motorcycle', 5000), ('car', 7500), ('motorcycle', 3500), ('motorcycle', 1500), ('car', 20000);
-- Call UDAF
SELECT python_avg(price) FROM sales;
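Because the handler is an ordinary Python class, you can check its logic locally before deploying it. The sketch below uses a condensed copy of `PythonAvg`, with an added guard for empty input. It also defines a hypothetical `simulate_udaf` driver that creates one handler per partition, accumulates rows, merges the partial states into a fresh handler, and calls `finish`. The driver is an assumption about the general shape of distributed aggregation, not a description of how Snowflake actually schedules these calls.

```python
from dataclasses import dataclass


@dataclass
class AvgAggState:
    count: int
    sum: int


class PythonAvg:
    # Condensed copy of the PythonAvg handler from the SQL example.
    def __init__(self):
        self._agg_state = AvgAggState(0, 0)

    @property
    def aggregate_state(self):
        return self._agg_state

    def accumulate(self, input_value):
        self._agg_state.sum += input_value
        self._agg_state.count += 1

    def merge(self, other_agg_state):
        self._agg_state.sum += other_agg_state.sum
        self._agg_state.count += other_agg_state.count

    def finish(self):
        if self._agg_state.count == 0:
            return None
        return self._agg_state.sum / self._agg_state.count


def simulate_udaf(handler_cls, partitions):
    """Hypothetical local driver: accumulate per partition, merge all partials, finish."""
    partials = []
    for rows in partitions:
        handler = handler_cls()
        for value in rows:
            handler.accumulate(value)
        partials.append(handler)
    final = handler_cls()
    for partial in partials:
        final.merge(partial.aggregate_state)
    return final.finish()


prices = [10000, 5000, 7500, 3500, 1500, 20000]  # the sales.price values above
result = simulate_udaf(PythonAvg, [prices[:3], prices[3:]])
print(result)  # 47500 / 6, about 7916.67
```

Splitting the rows differently should not change the result. A quick local check for that catches `merge` bugs that a single-partition test would miss.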

And here is an integer “sum” UDAF, defined and called with the Snowpark Python API:

from snowflake.snowpark.session import Session
from snowflake.snowpark.functions import udaf, col, call_function
from snowflake.snowpark.types import IntegerType

# Fill in your connection parameters.
session = Session.builder.configs({...}).create()

@udaf(name="sum_int", replace=True, return_type=IntegerType(), input_types=[IntegerType()])
class PythonSumUDAF:

    def __init__(self) -> None:
        # The aggregate state is a single Python int.
        self._sum = 0

    @property
    def aggregate_state(self):
        return self._sum

    def accumulate(self, input_value):
        # Add one input row to the running total.
        self._sum += input_value

    def merge(self, other_sum):
        # Fold in the partial total from another handler instance.
        self._sum += other_sum

    def finish(self):
        # Return the final result.
        return self._sum

# Example data
df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
# Call UDAF
df.agg(PythonSumUDAF("a")).collect()
# Alternate syntax to call the UDAF by name
df.select(call_function(PythonSumUDAF.name, col("a")).alias("sum_a")).collect()
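The same handler pattern also applies per group. In Snowpark, a grouped call such as `df.group_by("a").agg(PythonSumUDAF("b"))` should produce one sum of `b` per value of `a`. The hypothetical `local_group_by_sum` helper below previews that result without a Snowflake connection. It reuses the `PythonSumUDAF` logic in an undecorated class, keeps one partial handler per group key in each partition, and merges the partial states per key before calling `finish`. The two-partition split is illustrative only.

```python
from collections import defaultdict


class PythonSumHandler:
    # Same logic as PythonSumUDAF, without the @udaf decorator, so it runs locally.
    def __init__(self) -> None:
        self._sum = 0

    @property
    def aggregate_state(self):
        return self._sum

    def accumulate(self, input_value):
        self._sum += input_value

    def merge(self, other_sum):
        self._sum += other_sum

    def finish(self):
        return self._sum


def local_group_by_sum(partitions, key_index, value_index):
    """Hypothetical local preview of a grouped UDAF call."""
    # Phase 1: each partition builds a partial handler per group key.
    partials = []
    for rows in partitions:
        per_key = defaultdict(PythonSumHandler)
        for row in rows:
            per_key[row[key_index]].accumulate(row[value_index])
        partials.append(per_key)
    # Phase 2: merge the partial states for each key, then finish.
    merged = defaultdict(PythonSumHandler)
    for per_key in partials:
        for key, handler in per_key.items():
            merged[key].merge(handler.aggregate_state)
    return {key: handler.finish() for key, handler in merged.items()}


# The example rows [a, b] from above, split across two partitions.
partitions = [[[1, 3], [2, 5]], [[1, 4], [2, 6]]]
print(local_group_by_sum(partitions, key_index=0, value_index=1))  # {1: 7, 2: 11}
```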

Try It Out

We invite you to try out User-Defined Aggregate Functions in Snowflake.
