Orchestration Framework for running parallel containerized jobs in Snowflake
Introduction
In the rapidly evolving landscape of containerized data processing and ML workloads, efficient orchestration of containerized jobs running in parallel has become a critical requirement for businesses that want to streamline their data pipelines. If you have been using containers and executing multiple jobs with external tools like Argo, and you are interested in learning how to run containerized jobs in parallel directly in Snowflake, this blog post is for you.
Understanding the custom Orchestration Framework
The Orchestration Framework for running multiple containerized jobs in parallel is a custom, completely config-driven solution. With a simple JSON file you can specify which jobs should be executed and their dependencies, and spin up a fan-out and fan-in orchestration workflow. You create a DAG with the help of Snowflake tasks to define which container jobs run in parallel and which run in sequence.
Key features of the Custom Orchestration Framework
- Declarative workflow definitions (Fan-Out and Fan-In)
- Parallel job execution
- Automated resource management
- Built-in error handling and retry mechanisms
- Built on Snowflake's core data platform capabilities
Scenario
Envision a system that orchestrates multiple containerized jobs running concurrently, creating a fan-out and fan-in workflow. To accomplish this, you can leverage Snowpark Container Services job services and run your containerized pipelines directly in the platform.
By harnessing Snowflake's features, particularly Snowpark Container Services job services and tasks, you can effortlessly create a custom framework that manages this complex workflow.
What do I need?
To implement this scenario we will use the following Snowflake native features:
a) A JSON file, uploaded to a Snowflake stage, with all the dependencies and parameters needed for every job.
b) Snowpark Container Services job services
c) Snowflake Tasks
d) Snowpark Python Stored Procedures
JSON File Example
Below is a sample of the JSON file. You can customize it to suit your workflow needs. Most of the parameters are common to any service job and can be reused; only the details and parameters specific to your containers will change.
[{
"task_name":"t_myjob_1",
"image_name":"/pr_llmdemo/public/images/my_job_1_image:latest",
"compute_pool_name":"PR_STD_POOL_XS",
"job_name":"myjob_1",
"table_name":"results_1",
"retry_count":0,
"after_task_name":"root_task"
},
{
"task_name":"t_myjob_2",
"image_name":"/pr_llmdemo/public/images/my_job_2_image:latest",
"compute_pool_name":"PR_STD_POOL_S",
"job_name":"myjob_2",
"table_name":"results_2",
"retry_count":0,
"after_task_name":"root_task"
},
{
"task_name":"t_myjob_3",
"image_name":"/pr_llmdemo/public/images/my_job_3_image:latest",
"compute_pool_name":"PR_STD_POOL_S",
"job_name":"myjob_3",
"table_name":"results_3",
"retry_count":2,
"after_task_name":"t_myjob_2"
},
{
"task_name":"t_myjob_4",
"image_name":"/pr_llmdemo/public/images/my_job_4_image:latest",
"compute_pool_name":"PR_STD_POOL_XS",
"job_name":"myjob_4",
"table_name":"results_4",
"retry_count":1,
"after_task_name":"t_myjob_2"
},
{
"task_name":"t_myjob_5",
"image_name":"/pr_llmdemo/public/images/my_job_5_image:latest",
"compute_pool_name":"PR_STD_POOL_XS",
"job_name":"myjob_5",
"table_name":"results_5",
"retry_count":0,
"after_task_name":"t_myjob_1"
},
{
"task_name":"t_myjob_6",
"image_name":"/pr_llmdemo/public/images/my_job_6_image:latest",
"compute_pool_name":"PR_STD_POOL_S",
"job_name":"myjob_6",
"table_name":"results_6",
"retry_count":0,
"after_task_name":"t_myjob_3,t_myjob_4,t_myjob_5"
}
]
- task_name — Name of the task that will be created to run a specific service job.
- image_name — Name of the image in the Snowpark Container Services image repository.
- compute_pool_name — Name of the compute pool you want the container to run on. This compute pool must already be provisioned.
- job_name — Name of the SPCS service job.
- table_name — The application in the container writes its result to the table named in the JSON file; this parameter is specific to this demo. You can include whatever inputs your image requires.
- retry_count — The framework automatically attempts to re-execute failed jobs, retrying up to the number of times specified by retry_count.
- after_task_name — This field builds the dependency graph and is particularly powerful for creating fan-out and fan-in scenarios (see the generated task DDL sketch below).
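For illustration, here is a sketch of the DDL that create_job_tasks (shown later) generates for the fan-in entry t_myjob_6 above; the xs_wh warehouse name comes from that stored procedure:
-- Generated task for the fan-in job: it runs only after all three predecessors complete
CREATE OR REPLACE TASK t_myjob_6
WAREHOUSE = xs_wh
AFTER t_myjob_3, t_myjob_4, t_myjob_5
AS CALL ExecuteJobService('myjob_6', '/pr_llmdemo/public/images/my_job_6_image:latest', 'PR_STD_POOL_S', 'results_6', 0);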
Demo: Creating a Simple Parallel Containerized Job Workflow
Let's walk through an example of creating a parallel containerized job workflow using the custom Orchestration Framework. In this demo, we'll create a workflow that implements the fan-out and fan-in scenario using the JSON file above. We will use the same image for all the jobs, but you can specify a different image for each job in the JSON file.
The workflow is implemented through two Python stored procedures:
- create_job_tasks — Creates the DAG for workflow orchestration
- ExecuteJobService — Runs the container service jobs
You will find the implementation details of these Python SPs in the sections below.
Setup.sql
Create the required compute pools to run the containers. Here we are using the image from this tutorial; follow the instructions in the tutorial to create the image.
use role accountadmin;
CREATE COMPUTE POOL pr_std_pool_xs
MIN_NODES = 1
MAX_NODES = 1
INSTANCE_FAMILY = CPU_X64_XS;
DESCRIBE COMPUTE POOL PR_STD_POOL_XS;
CREATE COMPUTE POOL PR_STD_POOL_S
MIN_NODES = 1
MAX_NODES = 2
INSTANCE_FAMILY = CPU_X64_S;
show compute pools like 'PR_STD_POOL_S';
-- You can use any role that you have created instead of SPCS_PSE_ROLE
grant usage on compute pool pr_std_pool_xs to role SPCS_PSE_ROLE;
grant usage on compute pool pr_std_pool_s to role SPCS_PSE_ROLE;
use role SPCS_PSE_ROLE;
CREATE OR REPLACE STAGE JOBS DIRECTORY = (
ENABLE = true);
CREATE IMAGE REPOSITORY IF NOT EXISTS IMAGES;
show image repositories;
show compute pools like 'pr%';
Creating logging tables and a UDTF for tracking task status for a DAG
We are creating two logging tables:
- jobs_run_stats — Logs the status of each individual container job. If any errors occur while running a container, they are logged here. This table is used by the SP that executes the SPCS service jobs.
- task_logging_stats — Logs the status (success or failure) of every task that is part of the DAG. This table is used by the finalizer task.
use role SPCS_PSE_ROLE;
-- logging individual job status. This is used by the SP which is executing the SPCS Service Jobs
create or replace table jobs_run_stats (
    root_task_name string,
    task_name string,
    job_status string,
    graph_run_id string,
    graph_start_time timestamp_ltz,
    errors string,
    created_date datetime default current_timestamp()
);
-- Tracking all tasks part of the task graph. Used by the finalizer task
create table task_logging_stats (
    graph_run_group_id varchar,
    name varchar,
    state varchar,
    return_value varchar,
    query_start_time varchar,
    completed_time varchar,
    duration_in_secs int,
    error_message varchar
);
-- UDTF for getting the task status for the graph - TASK_GRAPH_RUN_STATS
create or replace function TASK_GRAPH_RUN_STATS(ROOT_TASK_ID string, START_TIME timestamp_ltz)
returns table (GRAPH_RUN_GROUP_ID varchar, NAME varchar, STATE varchar , RETURN_VALUE varchar,QUERY_START_TIME varchar,COMPLETED_TIME varchar, DURATION_IN_SECS INT,
ERROR_MESSAGE VARCHAR)
as
$$
select
GRAPH_RUN_GROUP_ID,
NAME,
STATE,
RETURN_VALUE,
to_varchar(QUERY_START_TIME, 'YYYY-MM-DD HH24:MI:SS') as QUERY_START_TIME,
to_varchar(COMPLETED_TIME,'YYYY-MM-DD HH24:MI:SS') as COMPLETED_TIME,
timestampdiff('seconds', QUERY_START_TIME, COMPLETED_TIME) as DURATION,
ERROR_MESSAGE
from
table(INFORMATION_SCHEMA.TASK_HISTORY(
ROOT_TASK_ID => ROOT_TASK_ID ::string,
SCHEDULED_TIME_RANGE_START => START_TIME::timestamp_ltz,
SCHEDULED_TIME_RANGE_END => current_timestamp()
))
$$
;
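If you want to inspect a graph run ad hoc, outside the finalizer task, you can call the UDTF directly. The root task id and start time below are placeholders you would replace with values from your own environment (for example, the id reported for root_task in SHOW TASKS output):
-- Ad-hoc check of a task graph run; both arguments are placeholder values
SELECT *
FROM TABLE(TASK_GRAPH_RUN_STATS('01b2c3d4-0000-0000-0000-000000000000', '2024-06-01 00:00:00'::timestamp_ltz));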
Workflow DAG — Fan-out and Fan-in Workflow Implementation (create_job_tasks Python SP)
The code contains the logic that creates the fan-out and fan-in workflow and does the following:
- Based on the config file passed at invocation, it creates the Snowflake task graph (for the fan-out and fan-in scenario); each task calls the Python SP (ExecuteJobService) created in the following section with the parameters fetched from the JSON config file (compute pool, image name, retry count, etc.).
- Every task depends on one or more other tasks. For example, T1 depends on root_task, T2 depends on root_task, and T3 depends on T1, which implements the required dependency workflow.
- The code also creates a finalizer task which tracks the status of every task (failure or success) and logs it into the task_logging_stats table.
use role SPCS_PSE_ROLE;
CREATE OR REPLACE PROCEDURE create_job_tasks(file_name string)
RETURNS string
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
PACKAGES = ('snowflake-snowpark-python')
HANDLER = 'create_jobservice_tasks'
AS
$$
from snowflake.snowpark.files import SnowflakeFile
import json

def create_jobservice_tasks(session, file_name):
    parent_task_name = 'root_task'
    parent_task_sql = f'''CREATE OR REPLACE TASK {parent_task_name}
        USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE = 'XSMALL'
        SCHEDULE = '59 MINUTE'
        AS
        SELECT CURRENT_TIMESTAMP();'''
    session.sql(f'''{parent_task_sql}''').collect()
    print(parent_task_sql)

    with SnowflakeFile.open(file_name) as j:
        json_data = json.load(j)

    for idx, task in enumerate(json_data):
        task_name = task['task_name']
        after_task_name = task['after_task_name']
        task_sql = f"CREATE OR REPLACE TASK {task_name} "
        task_sql += f" WAREHOUSE = xs_wh "
        task_sql += f" AFTER {after_task_name} "
        task_sql += f" AS CALL ExecuteJobService('{task['job_name']}','{task['image_name']}','{task['compute_pool_name']}','{task['table_name']}',{task['retry_count']})"
        # logger.info(f'{task_sql}')
        session.sql(f'''{task_sql}''').collect()
        print(task_sql)

    # This is the Finalizer task which gets the status of every task that is part of the DAG and loads it into the task_logging_stats table
    session.sql(f"""
        create or replace task GET_GRAPH_STATS
        warehouse = 'xs_wh'
        finalize = 'root_task'
        as
        declare
            ROOT_TASK_ID string;
            START_TIME timestamp_ltz;
        begin
            ROOT_TASK_ID := (call SYSTEM$TASK_RUNTIME_INFO('CURRENT_ROOT_TASK_UUID'));
            START_TIME := (call SYSTEM$TASK_RUNTIME_INFO('CURRENT_TASK_GRAPH_ORIGINAL_SCHEDULED_TIMESTAMP'));
            -- Insert into the logging table
            INSERT INTO task_logging_stats(GRAPH_RUN_GROUP_ID, NAME, STATE, RETURN_VALUE, QUERY_START_TIME, COMPLETED_TIME, DURATION_IN_SECS, ERROR_MESSAGE)
            SELECT * FROM TABLE(TASK_GRAPH_RUN_STATS(:ROOT_TASK_ID, :START_TIME)) where NAME != 'GET_GRAPH_STATS';
        end;
    """).collect()

    session.sql('alter task GET_GRAPH_STATS resume').collect()
    session.sql(f'''SELECT SYSTEM$TASK_DEPENDENTS_ENABLE('root_task')''').collect()
    return 'done'
$$;
Execute Job Service
This code does the heavy lifting of running the container. It is invoked from the tasks created above and does the following:
- Accepts the name of the service job to be created, the compute pool the service job will be executed on, some parameters that are inputs to the container, and the retry count, which determines how many times the code should retry executing the container before gracefully terminating in case of errors.
- For every service job execution we track the status (DONE or FAILED) in the jobs_run_stats table, including details of any container service job errors.
- This SP is invoked from the create_job_tasks SP, which creates the task DAG based on the job config file created above.
use role SPCS_PSE_ROLE;
CREATE OR REPLACE PROCEDURE ExecuteJobService(service_name VARCHAR, image_name VARCHAR, pool_name VARCHAR,table_name VARCHAR,retry_count INT)
RETURNS VARCHAR
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
PACKAGES = ('snowflake-snowpark-python')
HANDLER = 'create_job_service'
AS
$$
from snowflake.snowpark.functions import col
import uuid
import re
import logging
import sys

logger = logging.getLogger("python_logger")

def get_logger():
    """
    Get a logger for local logging.
    """
    logger = logging.getLogger("service-job")
    logger.setLevel(logging.INFO)
    return logger

# Function which invokes the execute job service
def execute_job(session, service_name, image_name, pool_name, table_name):
    # Drop the existing service if it exists
    session.sql(f'''DROP SERVICE if exists {service_name}''').collect()
    sql_qry = f'''
    EXECUTE JOB SERVICE
    IN COMPUTE POOL {pool_name}
    NAME={service_name}
    FROM SPECIFICATION
    '
    spec:
      container:
      - name: main
        image: {image_name}
        env:
          SNOWFLAKE_WAREHOUSE: xs_wh
        args:
        - "--query=select current_time() as time,''hello''"
        - "--result_table={table_name}"
    ';
    '''
    #print(sql_qry)
    try:
        _ = session.sql(sql_qry).collect()
    except Exception as e:
        logger.error(f"An error occurred running the app in the container: {e}")
    finally:
        job_status = session.sql(f'''SELECT parse_json(SYSTEM$GET_SERVICE_STATUS('{service_name}'))[0]['status']::string as Status''').collect()[0]['STATUS']
        return job_status

# This is the main function invoked by the SP handler.
# It calls execute_job to run the container with all the required parameters.
def create_job_service(session, service_name, image_name, pool_name, table_name, retry_count):
    logger = get_logger()
    logger.info("job_service")
    job_status = ''
    job_errors = ''
    current_root_task_name = ''
    current_task_name = ''
    current_graph_run_id = ''
    current_graph_start_time = ''
    try:
        cnt = retry_count
        # Execute the job service
        logger.info(f"Executing the Job [{service_name}] on pool [{pool_name}]")
        job_status = execute_job(session, service_name, image_name, pool_name, table_name)
        # Implementing the retry mechanism. The retry count value comes from the config file per job
        if job_status == 'FAILED':
            while cnt > 0:
                r_cnt = retry_count + 1 - cnt
                logger.info(f"Retrying Executing the Job [{service_name}] on pool [{pool_name}] - [{r_cnt}] out of {retry_count} times")
                job_status = execute_job(session, service_name, image_name, pool_name, table_name)
                if job_status == 'DONE':
                    break
                cnt = cnt - 1
        if job_status == 'FAILED':
            job_errors = re.sub(r"'", r"\\'", session.sql(f'''
                select SYSTEM$GET_SERVICE_LOGS('{service_name}', 0, 'main')::string as logs;
            ''').collect()[0]['LOGS'])
        else:
            job_errors = ''
        # Getting the DAG task details. SYSTEM$TASK_RUNTIME_INFO only works inside a task.
        result = session.sql("""select
            SYSTEM$TASK_RUNTIME_INFO('CURRENT_ROOT_TASK_NAME') root_task_name,
            SYSTEM$TASK_RUNTIME_INFO('CURRENT_TASK_NAME') task_name,
            SYSTEM$TASK_RUNTIME_INFO('CURRENT_TASK_GRAPH_RUN_GROUP_ID') run_id,
            SYSTEM$TASK_RUNTIME_INFO('CURRENT_TASK_GRAPH_ORIGINAL_SCHEDULED_TIMESTAMP') dag_start_time
        """).collect()[0]
        current_root_task_name = result.ROOT_TASK_NAME
        current_task_name = result.TASK_NAME
        current_graph_run_id = result.RUN_ID
        current_graph_start_time = result.DAG_START_TIME
        # Inserting job status into the logging table
        _ = session.sql(f'''
            INSERT INTO jobs_run_stats
            (root_task_name, task_name, graph_run_id, job_status, graph_start_time, errors, created_date)
            SELECT '{current_root_task_name}'
                 , '{current_task_name}'
                 , '{current_graph_run_id}'
                 , '{job_status}'
                 , '{current_graph_start_time}'
                 , '{job_errors}'
                 , current_timestamp()
        ''').collect()
        return job_status
    except Exception as e:
        print(f"An error occurred: {e}")
        if job_status == 'FAILED':
            job_errors = re.sub(r"'", r"\\'", session.sql(f'''
                select SYSTEM$GET_SERVICE_LOGS('{service_name}', 0, 'main')::string as logs;
            ''').collect()[0]['LOGS'])
        else:
            job_errors = ''
        session.sql(f"""
            INSERT INTO jobs_run_stats(task_name, errors, graph_run_id, job_status, created_date)
            SELECT '{service_name}',
                   '{job_errors}',
                   '{current_graph_run_id}',
                   '{job_status}',
                   current_timestamp()
        """).collect()
        return f'Error Occurred.. Refer to the job error column - {e}'
$$;
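You can also smoke-test a single job outside the DAG by calling the procedure directly; the values below come from the sample jobconfig.json. Note that SYSTEM$TASK_RUNTIME_INFO only works inside a task, so a standalone call will record its result through the exception branch rather than with full DAG details:
-- Standalone test of one service job (values taken from the sample config)
CALL ExecuteJobService('myjob_1', '/pr_llmdemo/public/images/my_job_1_image:latest', 'PR_STD_POOL_XS', 'results_1', 0);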
Running the Container Workflow Orchestration Utility
Here we invoke the orchestration workflow SP, passing the job config file (uploaded to a Snowflake stage) that has all the details required for the tasks to be created for the Fan-Out and Fan-In scenarios.
Upload the JSON file to the JOBS stage that you created above.
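A minimal sketch of the upload, assuming jobconfig.json is in your current local directory and you are using a client such as SnowSQL (you can also upload it through the Snowsight stage UI):
-- Upload the job config to the JOBS stage without compressing it, so the scoped file URL resolves to jobconfig.json
PUT file://jobconfig.json @jobs AUTO_COMPRESS = FALSE OVERWRITE = TRUE;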
call create_job_tasks(build_scoped_file_url(@jobs, 'jobconfig.json'));
After executing the SP, the DAG below is created.
Checking the tasks that are part of the DAG:
-- Check the DAG tasks created under root_task. The predecessors column shows the task(s) each task depends on
select *
from table(information_schema.task_dependents(task_name => 'root_task', recursive => true));
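The root task is created with a 59-minute schedule, so rather than waiting for the next scheduled run you can trigger the task graph manually:
-- Kick off a run of the DAG immediately
EXECUTE TASK root_task;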
Viewing Execution Stats Using the Logging Tables
Using the two logging tables we created, let's check the status of each Container Service Job and the status of every task that invokes a service job.
-- View job run status. This is per Service job logging
select top 10 * from jobs_run_stats order by created_date desc;
Every service job (which invokes the container) shares the same graph run ID (GRAPH_RUN_ID) per DAG execution.
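For example, to list every job in the most recent graph run together with its status (a simple sketch):
-- Jobs and statuses for the latest DAG run
select task_name, job_status, errors
from jobs_run_stats
where graph_run_id = (select graph_run_id from jobs_run_stats order by created_date desc limit 1);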
The query below gives additional metrics about the duration of each task (which executes the container service jobs).
-- Query task logging status (by the finalizer task)
SELECT top 10 * FROM task_logging_stats ORDER BY CAST(QUERY_START_TIME AS DATETIME) DESC;
Viewing Error Details
If there are any errors while running the containers, we track that information in the logging tables we created. You can always go back and check the error details in the jobs_run_stats table.
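For example (a sketch):
-- Most recent failed jobs and the container logs captured for them
select created_date, task_name, job_status, errors
from jobs_run_stats
where job_status = 'FAILED'
order by created_date desc;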
SP to drop the tasks — Cleanup
The code below deletes the tasks that were created as part of the framework.
CREATE OR REPLACE PROCEDURE drop_job_tasks()
RETURNS string
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
PACKAGES = ('snowflake-snowpark-python')
HANDLER = 'drop_tasks'
execute as caller
AS
$$
from snowflake.snowpark.files import SnowflakeFile
import json

def drop_tasks(session):
    session.sql('alter task root_task suspend').collect()
    res = session.sql(f'''select name
        from table(information_schema.task_dependents(task_name => 'root_task', recursive => true))''').collect()
    for r in res:
        print(r.NAME)
        session.sql(f'drop task {r.NAME}').collect()
    session.sql('drop task GET_GRAPH_STATS').collect()
    return 'Done'
$$;
-- Deleting the DAG (Task Graphs)
call drop_job_tasks();
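Once you are done with the demo you may also want to stop the compute pools. Suspending (or dropping) them is optional and independent of the framework:
-- Optional cleanup of the demo compute pools
use role accountadmin;
alter compute pool pr_std_pool_xs suspend;
alter compute pool pr_std_pool_s suspend;
-- or remove them entirely:
-- drop compute pool pr_std_pool_xs;
-- drop compute pool pr_std_pool_s;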
Conclusion
The custom SPCS Orchestration Framework provides a solution for running parallel containerized jobs within Snowflake using core platform native features.