Productionize your ML workflow using Snowflake Task/DAG APIs
Tasks can be defined to automate data engineering workflows, specifically recurring tasks, natively in Snowflake. However, writing a full-scale pipeline with many tasks has been hard. Often, customers have built their own framework on top of Snowflake (probably in Python) to solve a number of programmability issues: constructing logical groups of tasks, retries, dynamic behavior, configurability, deployment, code reuse, backfilling, catchup, ad hoc runs, logging, and root cause analysis (RCA).
Today, many data engineering workloads (inside and outside of Snowflake) are written in Python, especially for machine learning (ML). So, around a year ago, engineers across Snowflake worked to solve this exact problem: How can Tasks be made even better with DAG & Python programmability?
To start answering that question and addressing customer needs, Snowflake introduced the following features:
- Richer Task Graph features in GA
  - Finalizer Task serves as a try-finally block for a Task Graph. The finalizer task is guaranteed to run if the Task Graph is executed, which ensures proper resource cleanup and completion of necessary steps. For example, if a Task Graph run uses intermediate tables to track data for processing and fails before the table rows are consumed, the next run will encounter duplicate rows and reprocess data, resulting in a longer execution time or wasted compute resources. The finalizer task can drop the rows or truncate the table as needed to avoid these issues.
  - Runtime Reflection Variables are system functions that let users customize and obtain key runtime information about their Task Graph. For example, users can now use various out-of-the-box functions to determine a time range to process data from a stream.
  - Retry Failed Task Graph enables rerunning a Task Graph from the point of the last failure, skipping tasks that already completed successfully, thereby saving computational resources and reducing latency.
  - Task Graph Configuration enables customers to specify configurations that can be used by all tasks in the Task Graph. Previously, users would hardcode configuration values, making it difficult to manage their Task Graphs. Now, users can easily clone Tasks and provide different configurations for pre-production or production environments.
  - Task Predecessor Return Value enables users to create dynamic behavior such as the conditional execution of a Task based on the return values of any predecessor task.
- Snowflake Python API in public preview (PuPr)
  - Snowflake Python API provides Python APIs for all Snowflake resources across Data Engineering, Snowpark, and App workloads. This comes with full support for Tasks & DAG.
  - Developers can simply run pip install snowflake to get started.
- Enhanced Task UI in Snowsight in private preview (PrPr)
  - This enables users to easily manage and monitor all tasks from a scheduled graph in one place with their runtime, failure errors, and so on.
- Snowpark ML
  - Modeling API GA: Python API for preprocessing data, feature engineering, and training models inside Snowflake following familiar Python frameworks such as scikit-learn and XGBoost.
  - Model Registry API PuPr: Snowpark Model Registry allows customers to securely manage models and their metadata in Snowflake, regardless of origin.
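To make the try-finally analogy for the Finalizer Task concrete, here is a minimal plain-Python sketch. It does not use any Snowflake API; run_graph, staging_rows, and the three task functions are hypothetical names made up for illustration, and a real Task Graph is scheduled by Snowflake rather than by a loop like this.

```python
from typing import Callable, List


def run_graph(tasks: List[Callable[[], None]], finalizer: Callable[[], None]) -> bool:
    """Run tasks in order and always run the finalizer, like a try-finally block.

    Returns True if every task succeeded, False if one of them raised.
    """
    try:
        for task in tasks:
            task()
        return True
    except Exception:
        # A real scheduler would record the failure; this sketch only reports it.
        return False
    finally:
        # Runs on success and on failure alike.
        finalizer()


# Hypothetical intermediate "table" that tracks rows still to be processed.
staging_rows: List[int] = []


def load() -> None:
    staging_rows.extend([1, 2, 3])


def process() -> None:
    raise RuntimeError("simulated failure before the rows are consumed")


def truncate_staging() -> None:
    staging_rows.clear()


# The second task fails, yet the intermediate rows are still cleared.
succeeded = run_graph([load, process], finalizer=truncate_staging)
```

Because the cleanup sits in the finally clause, the next run starts from an empty staging list and does not reprocess duplicate rows.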
In this blog, we’ll go over a real-world example of using and managing a production pipeline end-to-end within Snowflake.
An example: A daily training on transaction data for fraud detection
It is very common for companies to have their customers’ data (transactions, enrollment, etc.) constantly coming in and being stored in an ongoing transactions table. For this demo, we are going to use Snowpark ML to predict whether financial transactions are fraudulent or not. This will show not only how you can use Snowpark ML for the entirety of a typical ML workflow, including development and deployment, but also how that workflow can be seamlessly deployed to Snowflake for automation using the DAG API.
To do that, we have both customer-level and account-level data, as well as labeled (fraud/no fraud) individual transaction data. Let’s assume these live in two tables, FINANCIAL_ACCOUNT and FINANCIAL_TRANSACTIONS, where the latter is constantly updated with the latest transactions as they happen. We want to run a daily training job on the financial transactions from the prior seven days to predict fraud. This new model will be tested against the current production model. If it passes the evaluation with test data, the daily production pipeline will automatically push the new model into production by marking it live.
First of all, let’s talk about some housekeeping and setup for our pipeline. For better debugging and isolation, we should run our production pipeline under a separate database, say, DAG_RUNS. Under that, we would keep all production settings, such as task definitions, under a schema, say, SCHEDULED_RUNS. Every daily run of the pipeline will have its own schema, named like RUN_<runid>, created exclusively for one run of the DAG and deleted at the end. In case of failure, it will be easy to debug by peeking into that particular schema.
Let’s first define a simple DAG.
- We have defined a DAG or pipeline named daily_training which by default runs on the given warehouse.
- A “daily” schedule has been set using schedule.
- Setting use_func_return_value tells the DAG to use the return value of each Python function as the return value of its task.
- We need to provide a stage_location, where the Python API will store the serialized version of the tasks.
- All the tasks under this DAG would run with this default list of packages and the warehouse. Both the warehouse and packages for any particular task under the DAG can be overridden to have different values.
- A config representing Task Graph Configuration is automatically made available to every child task to use.
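The default-and-override rule for the warehouse and packages can be pictured with a small plain-Python sketch. This is not the DAG API itself; effective_settings, DAG_DEFAULTS, and the warehouse names are assumptions chosen for illustration.

```python
from typing import Any, Dict, Optional

# DAG-level defaults; the warehouse name is a hypothetical placeholder.
DAG_DEFAULTS: Dict[str, Any] = {
    "warehouse": "TRAINING_WH",
    "packages": ["snowflake-snowpark-python", "snowflake-ml-python"],
}


def effective_settings(
    defaults: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Return the settings a task would run with: defaults, then task overrides."""
    settings = dict(defaults)  # copy, so the DAG defaults are never mutated
    settings.update(overrides or {})
    return settings


# A task that needs a bigger (hypothetical) warehouse but the default packages.
training_settings = effective_settings(DAG_DEFAULTS, {"warehouse": "LARGE_WH"})
```

A task that passes no overrides simply inherits every DAG-level value.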
With that, we now need to set up the environment for every run of the DAG, which includes setting some variables, creating the RUN_<runid> schema, and so on. This task creates a config JSON for the rest of the DAG and returns it. Note that this config is created dynamically at runtime, based on various runtime info obtained using get_runtime_info(). There is also a separate task graph config that is set statically when the DAG is created and can be retrieved via get_task_graph_config().
As you saw above, we just returned some config from the first task (a function for now; it will be set up as a task later). Within the DAG, other tasks can retrieve their predecessors’ return values, the task graph config, the root task name, etc. via a convenient TaskContext API.
Now it is time to select transaction data from START_TIME to END_TIME and perform some feature engineering to prepare it for model training. In the end, the task writes the final result into a table named final_data inside the RUN_<runid> schema. This is very useful for debugging if the DAG fails in a subsequent step. While this table name could be made constant and put in the DAG config, here we simply pass it from the feature_eng task to subsequent tasks through the return value, as an example.
This step is very time-consuming, so it makes sense for it to be a separate job. Often the same features are used in multiple models, so technically the DAG can split into multiple sub-DAGs from here.
Lots of details here are removed for brevity. However, you can find them in the sf-samples repo mentioned at the end of this post.
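As a rough sketch of what such a setup step might compute, the plain-Python function below builds the per-run config and serializes it to JSON so it can travel as a task return value. The function name, the config keys, and the way the run ID and run time are passed in are all assumptions; in the real task these values would come from the runtime info described above.

```python
import json
from datetime import datetime, timedelta, timezone


def build_run_config(
    run_id: int,
    run_time: datetime,
    db_name: str = "DAG_RUNS",
    lookback_days: int = 7,
) -> str:
    """Build the per-run config and serialize it as a JSON string."""
    end_time = run_time
    start_time = end_time - timedelta(days=lookback_days)
    config = {
        "run_id": run_id,
        "run_schema": f"{db_name}.RUN_{run_id}",  # one schema per DAG run
        "start_time": start_time.isoformat(),
        "end_time": end_time.isoformat(),
    }
    return json.dumps(config)


# Example values only: a run on 2024-01-13 covers the prior seven days.
example = build_run_config(1705129424736, datetime(2024, 1, 13, tzinfo=timezone.utc))
```

A downstream task would call json.loads on the returned string to recover the dictionary.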
Split data for training & testing
This job just splits the data randomly and writes to two different tables for subsequent jobs.
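The idea of a reproducible random split can be sketched in plain Python as follows. This is only an illustration on an in-memory list; split_rows and its parameters are hypothetical, and the real job operates on Snowflake tables rather than Python lists.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_rows(
    rows: Sequence[T], train_fraction: float = 0.8, seed: int = 42
) -> Tuple[List[T], List[T]]:
    """Shuffle rows with a fixed seed and split them into (train, test)."""
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be strictly between 0 and 1")
    shuffled = list(rows)
    # A dedicated Random instance keeps the split reproducible across runs.
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


train_rows, test_rows = split_rows(range(100))
```

Fixing the seed means a rerun of a failed DAG sees the same train and test sets, which makes debugging easier.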
This is the main model training. Note that we are using Snowpark ML Modeling API (snowflake.ml.modeling.xgboost.XGBClassifier) for training and Model Registry (snowflake.ml.registry) for logging the model here.
Now that training is done and the model is logged, it is time for evaluation against the test data. We’ll also update the registry with the metric computed here. This will be very useful for future debugging. While we are computing multiple metrics, we are just returning accuracy because that is what we would like to use to split the DAG into two — if accuracy is higher than a threshold, go ahead and push to production.
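For reference, the accuracy metric returned by this step is just the fraction of correct predictions. The helper below is a minimal plain-Python sketch with a hypothetical name; the actual pipeline would more likely compute its metrics with library functions over the test table.

```python
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Return the fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if len(y_true) == 0:
        raise ValueError("cannot compute accuracy on empty input")
    correct = sum(1 for truth, pred in zip(y_true, y_pred) if truth == pred)
    return correct / len(y_true)


# Three of the four predictions match the labels.
example_accuracy = accuracy([1, 0, 1, 1], [1, 0, 0, 1])
```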
Push to production
In this step, we’ll just mark the model as live, and the next scheduled run of batch inference will pick it up. Note that the batch inference side is not shown here; see the model registry documentation to learn more. But before that, we would like to put a condition in place ensuring accuracy > 0.99.
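The branch condition itself can be as small as the sketch below: a function that looks at the accuracy and returns the name of the task to run next. The task names promote_model and cleanup are hypothetical, and how the returned name is wired into the graph depends on the DAG API.

```python
ACCURACY_THRESHOLD = 0.99  # the threshold stated in the text


def choose_next_task(accuracy: float, threshold: float = ACCURACY_THRESHOLD) -> str:
    """Return the (hypothetical) name of the task that should run next."""
    # Strictly greater than: a model exactly at the threshold is not promoted.
    return "promote_model" if accuracy > threshold else "cleanup"


next_task = choose_next_task(0.995)
```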
Finally, clean up
Here we would just clean up the schema (RUN_<runid>) we created in setup(). Note that we could have employed multiple strategies here:
- Keep the schema for some time around and have a separate task to clean very old schemas.
- Delete the intermediate schema even if the DAG fails, using a Finalizer task, to reduce cost (the DAG API, currently in PuPr, is expected to be improved for this in the Python API soon).
- Delete the intermediate schema only in case of a successful run of the DAG; otherwise leave it for on-call to delete manually.
Here we have chosen option three, and therefore added a regular last task as follows.
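A cleanup task along these lines mainly needs to build one DROP SCHEMA statement for the run. The sketch below only constructs the SQL text and checks that the names are plain unquoted identifiers; the function name is hypothetical, and executing the statement through a session is left out.

```python
import re

# Unquoted identifier: starts with a letter or underscore, then letters,
# digits, underscores, or dollar signs.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")


def drop_run_schema_sql(db_name: str, run_id: int) -> str:
    """Build the statement that removes the per-run schema RUN_<runid>."""
    schema = f"RUN_{run_id}"
    for name in (db_name, schema):
        if not _IDENTIFIER.fullmatch(name):
            raise ValueError(f"unsafe identifier: {name!r}")
    return f"DROP SCHEMA IF EXISTS {db_name}.{schema}"


cleanup_sql = drop_run_schema_sql("DAG_RUNS", 1705129424736)
```

Using IF EXISTS keeps the task safe to rerun if the schema has already been removed manually.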
Where did I define my graph?
So far we have defined a whole bunch of Python functions depicting a seven-task pipeline, or DAG. However, we have not yet created the DAG and pushed it to Snowflake. Let’s do that now using the Snowflake Python API.
Now you can check the graph in Snowsight to see how the graph has been set up.
There is a lot going on here:
- Task graph configuration specifies the runtime environment for the entire DAG. With a simple change in DB_NAME one can deploy the same DAG in a different environment. This comes in very handy in real-world scenarios when a production pipeline runs in multiple environments like production, development and regtest environments.
- With branching, we can traverse different parts of the sub-tree of tasks, and the branch condition can be based on a value determined at runtime.
- For one-off testing/running the DAG, users can just skip the schedule and manually trigger a run with dag_op.run(dag).
- You may have noticed that you can even pass the raw function to the dependency definition without explicitly creating the DAGTask. The Python library automatically creates a task for you with the same name. But there are a few exceptions:
  - You will need to explicitly create a DAGTask for the first task.
  - DAGTaskBranch needs to be instantiated explicitly.
  - If you were to repeat certain functions in multiple tasks (e.g., cleanup() here), you need to explicitly give separate unique task names.
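To see how task dependencies determine execution order, the sketch below models a simplified, linear version of the pipeline with Python's standard graphlib module (Python 3.9+). It is not the Snowflake DAG API; apart from setup, feature_eng, and cleanup, the task names are hypothetical, and the accuracy branch is left out to keep the example small.

```python
from graphlib import TopologicalSorter

# Each key maps a task to the set of tasks that must finish before it.
DEPENDENCIES = {
    "setup": set(),
    "feature_eng": {"setup"},
    "split_data": {"feature_eng"},
    "train_model": {"split_data"},
    "evaluate_model": {"train_model"},
    "promote_model": {"evaluate_model"},
    "cleanup": {"promote_model"},
}

# static_order() yields every task after all of its predecessors.
execution_order = list(TopologicalSorter(DEPENDENCIES).static_order())
```

Because this simplified graph is a single chain, there is exactly one valid order, running from setup through cleanup.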
After you deploy() and run the DAG, you can see its status easily using:
Which outputs something like: RunId=1705129424736 State=EXECUTING.
In case the DAG fails, it is easy to see the error message directly from Python:
In the output, make sure to look for the same RunId, because the output may contain multiple RunIds. The real error is printed there, so you can start debugging right away.
In case fixing the bug does not require an update to the task definition, you could just do the following to retry the last failed task:
For one-off runs, you can also omit the schedule argument to the DAG object without worrying about it running in the background unnecessarily.
What about metrics from models used in production?
Now inspect all the models in the registry that have been published by this pipeline.
As you can see above, we have marked the latest model as default. This means we have been successfully promoting the model; in other words, its accuracy has met the threshold. One of the biggest advantages of the model registry is the ability to inspect and manage models. So we would like to plot the accuracy over time for the models we have developed.
What did we learn?
- The uses of Tasks in Snowflake:
  - How to productionize a Python-based ML workload
  - How to effectively pass data across tasks
  - Python DAG API
  - How an effective UI makes life simpler by visualizing runtime and giving easy access to error logs
- Snowpark ML APIs
  - How simple it is to deal with data at Snowflake and train your model
  - How easy it is to manage those models using the registry, which captures all the metadata like accuracy, which can be plotted and observed
All the code mentioned here can be found in the sf-samples repo.
Get started today!
Do try our new Task/DAG APIs with Snowpark ML. We would love to hear your feedback at email@example.com.