ML on Snowflake at scale with Snowpark Python and XGBoost

Over the last year and a half, our team has been working with hundreds of customers building on Snowflake with Snowpark and Python (by the way, check out this post by my colleague Caleb Baechtold that compiles our learnings about operationalizing Snowpark Python in production). A question I frequently get is, “But we have really big data, how does it work at scale?”

Many of the examples and demos you will find online do a good job of showing core functionality with simple, contrived datasets, but few if any show how it works in the context of large, enterprise-scale problems.

TPC-DS is a useful large-scale dataset that represents a generic business data schema. (I’m not here to advocate that TPC-DS, or any dataset, be the be-all-end-all benchmark for any platform or tool; I just think it is a useful starting point for testing things at scale.) Snowflake makes TPC-DS available in every Snowflake account as a data share, in both a 10 TB and a 100 TB edition. (Because it’s exposed as a data share, you as the user do not pay any of the storage costs, only for the compute you use to query it.) The 100 TB edition has over 560 billion rows in the fact tables; the 10 TB edition, 56 billion.
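The share is exposed through the SNOWFLAKE_SAMPLE_DATA database, so the tables the rest of this post uses can be read straight into Snowpark DataFrames. Here is a minimal sketch of the setup the code below assumes (the connection parameters are placeholders, and the schema names assume the standard sample data share is mounted in your account):

from snowflake.snowpark import Session

# Connection parameters are placeholders; fill in your own account details
session = Session.builder.configs({
    'account': '<account_identifier>',
    'user': '<user>',
    'password': '<password>',
    'role': '<role>',
    'warehouse': '<warehouse>',
}).create()

# TPC-DS tables from the sample data share
# (TPCDS_SF10TCL is the 10 TB edition, TPCDS_SF100TCL is the 100 TB edition)
tpcds = 'snowflake_sample_data.tpcds_sf10tcl'
store_sales = session.table(f'{tpcds}.store_sales')
web_sales = session.table(f'{tpcds}.web_sales')
catalog_sales = session.table(f'{tpcds}.catalog_sales')
customer = session.table(f'{tpcds}.customer')
address = session.table(f'{tpcds}.customer_address')
demo = session.table(f'{tpcds}.customer_demographics')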

We can use Snowpark Python to build a relatively simple ML solution to a common business problem a hypothetical business like TPC might have: “I want to be able to predict the lifetime value of my customers across all sales channels.” We will use the Snowpark Python DataFrame API for data prep and feature engineering, stored procedures and Snowpark Optimized Warehouses for training, and batch UDFs for inference, all without the data ever needing to leave Snowflake, by utilizing Snowflake’s compute resources and ability to scale.

Let’s start with our data prep and feature engineering code. Quite simply, we will aggregate sales by customer across all channels, then join the result to our customer dimension tables to pick up potential features of interest.

from snowflake.snowpark import functions as F

# Aggregate sales per customer in each channel
store_sales_agged = store_sales.group_by('ss_customer_sk').agg(F.sum('ss_sales_price').as_('total_sales'))
web_sales_agged = web_sales.group_by('ws_bill_customer_sk').agg(F.sum('ws_sales_price').as_('total_sales'))
catalog_sales_agged = catalog_sales.group_by('cs_bill_customer_sk').agg(F.sum('cs_sales_price').as_('total_sales'))

# Align the customer key column names so the channels can be unioned
store_sales_agged = store_sales_agged.rename('ss_customer_sk', 'customer_sk')
web_sales_agged = web_sales_agged.rename('ws_bill_customer_sk', 'customer_sk')
catalog_sales_agged = catalog_sales_agged.rename('cs_bill_customer_sk', 'customer_sk')

# Union all channels and roll up to one total per customer
total_sales = store_sales_agged.union_all(web_sales_agged)
total_sales = total_sales.union_all(catalog_sales_agged)
total_sales = total_sales.group_by('customer_sk').agg(F.sum('total_sales').as_('total_sales'))

# Join the customer, address, and demographics dimensions to pick up candidate features
customer = customer.select('c_customer_sk', 'c_current_hdemo_sk', 'c_current_addr_sk', 'c_customer_id', 'c_birth_year')
customer = customer.join(address.select('ca_address_sk', 'ca_zip'), customer['c_current_addr_sk'] == address['ca_address_sk'])
customer = customer.join(demo.select('cd_demo_sk', 'cd_gender', 'cd_marital_status', 'cd_credit_rating', 'cd_education_status', 'cd_dep_count'),
                         customer['c_current_hdemo_sk'] == demo['cd_demo_sk'])
customer = customer.rename('c_customer_sk', 'customer_sk')

final_df = total_sales.join(customer, on='customer_sk')

# Persist the result as a reusable feature table
session.use_database('tpcds_xgboost')
session.use_schema('demo')
final_df.write.mode('overwrite').save_as_table('feature_store')

We save this to a feature store table for reuse later in experimentation or production settings (more another time on how many customers think of Snowflake as an offline feature store).
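Reusing it later is just another table read. As a small sketch, you could pull a sample of the saved features back into a notebook for experimentation (the table name matches what we saved above):

# Read the persisted features back for ad hoc experimentation
features_df = session.table('tpcds_xgboost.demo.feature_store')

# Inspect a small sample locally without pulling the full table to the client
sample_pdf = features_df.sample(n=10000).to_pandas()
print(sample_pdf.describe())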

Now we’re ready to train our model using a stored procedure. Through our partnership with Anaconda, Snowflake makes the popular Python ML frameworks available for training out of the box. No need to learn a new library or syntax.
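If you want to confirm that a particular package and version is available in the Snowflake Anaconda channel before referencing it, you can query the information schema. A quick check, assuming your role can read the current database’s INFORMATION_SCHEMA:

# List the xgboost versions available in Snowflake's Anaconda channel
session.sql(
    "select package_name, version "
    "from information_schema.packages "
    "where language = 'python' and package_name = 'xgboost'"
).show()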

import os
import joblib
import snowflake.snowpark
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.metrics import mean_squared_error
from sklearn.compose import ColumnTransformer
from xgboost import XGBRegressor

def train_model(session: snowflake.snowpark.Session) -> float:
    snowdf = session.table("feature_store")
    snowdf = snowdf.drop(['CUSTOMER_SK', 'C_CURRENT_HDEMO_SK', 'C_CURRENT_ADDR_SK', 'C_CUSTOMER_ID', 'CA_ADDRESS_SK', 'CD_DEMO_SK'])
    snowdf_train, snowdf_test = snowdf.random_split([0.8, 0.2], seed=82)

    # save the train and test sets as tables in Snowflake for reproducibility
    snowdf_train.write.mode("overwrite").save_as_table("tpcds_xgboost.demo.tpc_TRAIN")
    snowdf_test.write.mode("overwrite").save_as_table("tpcds_xgboost.demo.tpc_TEST")

    train_x = snowdf_train.drop("TOTAL_SALES").to_pandas()  # drop the label for the training set
    train_y = snowdf_train.select("TOTAL_SALES").to_pandas()
    test_x = snowdf_test.drop("TOTAL_SALES").to_pandas()
    test_y = snowdf_test.select("TOTAL_SALES").to_pandas()
    cat_cols = ['CA_ZIP', 'CD_GENDER', 'CD_MARITAL_STATUS', 'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS']
    num_cols = ['C_BIRTH_YEAR', 'CD_DEP_COUNT']

    # Impute and scale the numeric features
    num_pipeline = Pipeline([
        ('imputer', SimpleImputer(strategy="median")),
        ('std_scaler', StandardScaler()),
    ])

    # One-hot encode the categorical features alongside the numeric pipeline
    preprocessor = ColumnTransformer(
        transformers=[('num', num_pipeline, num_cols),
                      ('encoder', OneHotEncoder(handle_unknown="ignore"), cat_cols)])

    pipe = Pipeline([('preprocessor', preprocessor),
                     ('xgboost', XGBRegressor())])
    pipe.fit(train_x, train_y)

    test_preds = pipe.predict(test_x)
    rmse = mean_squared_error(test_y, test_preds) ** 0.5  # root mean squared error

    # Serialize the fitted pipeline and upload it to the model stage
    model_file = os.path.join('/tmp', 'model.joblib')
    joblib.dump(pipe, model_file)
    session.file.put(model_file, "@ml_models", overwrite=True)
    return rmse

In this case, we are using sklearn for the transformations and pipeline specification, and XGBoost as our regressor. This function can now be registered as a stored procedure and invoked via the Python API:

# Register the training function as a stored procedure, declaring the packages
# it needs in the server-side Python environment
train_model_sp = F.sproc(train_model, session=session, replace=True,
                         packages=['snowflake-snowpark-python', 'scikit-learn', 'xgboost', 'joblib'])

# Switch to a Snowpark Optimized Warehouse for training and run the stored proc
session.use_warehouse('snowpark_opt_wh')
train_model_sp(session=session)
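The code above assumes a Snowpark Optimized Warehouse named snowpark_opt_wh and a stage named ml_models already exist. If they don’t, here is a sketch of creating them first (run before invoking the stored procedure; the sizing is illustrative):

# Create a Snowpark Optimized Warehouse for memory intensive training workloads
session.sql("""
    create warehouse if not exists snowpark_opt_wh
    with warehouse_size = 'MEDIUM'
    warehouse_type = 'SNOWPARK-OPTIMIZED'
""").collect()

# The training sproc writes the serialized model to this stage
session.sql("create stage if not exists ml_models").collect()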

Now we’re ready to deploy this model as a UDF, where it can be used for batch inference jobs.

import os
import sys
import pandas as pd
import cachetools
import joblib
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T

# Make the serialized model available to the UDF's execution environment
session.add_import("@ml_models/model.joblib")

features = ['C_BIRTH_YEAR', 'CA_ZIP', 'CD_GENDER', 'CD_MARITAL_STATUS', 'CD_CREDIT_RATING', 'CD_EDUCATION_STATUS', 'CD_DEP_COUNT']

@cachetools.cached(cache={})
def read_file(filename):
    # Load the model once per UDF process and cache it across batches
    import_dir = sys._xoptions.get("snowflake_import_directory")
    if import_dir:
        with open(os.path.join(import_dir, filename), 'rb') as file:
            m = joblib.load(file)
            return m

@F.pandas_udf(session=session, max_batch_size=10000, is_permanent=True,
              stage_location='@ml_models', name="clv_xgboost_udf",
              packages=['pandas', 'scikit-learn', 'xgboost', 'joblib', 'cachetools'])
def predict(df: T.PandasDataFrame[int, str, str, str, str, str, int]) -> T.PandasSeries[float]:
    m = read_file('model.joblib')
    df.columns = features
    return m.predict(df)

Invoking the model is quite simple and can be done in either Python or SQL.

inference_df = session.table('feature_store')
inference_df = inference_df.drop(['CUSTOMER_SK', 'C_CURRENT_HDEMO_SK', 'C_CURRENT_ADDR_SK', 'C_CUSTOMER_ID', 'CA_ADDRESS_SK', 'CD_DEMO_SK'])
inputs = inference_df.drop("TOTAL_SALES")

# Score every row with the UDF; *inputs expands to the feature columns
snowdf_results = inference_df.select(*inputs,
                                     predict(*inputs).alias('PREDICTION'),
                                     (F.col('TOTAL_SALES')).alias('ACTUAL_SALES'))
snowdf_results.write.mode('overwrite').save_as_table('predictions')

The SQL equivalent, which is what the Python above is translated into (and which you could run directly), is the following:

SELECT "C_BIRTH_YEAR", 
"CA_ZIP",
"CD_GENDER",
"CD_MARITAL_STATUS",
"CD_CREDIT_RATING",
"CD_EDUCATION_STATUS",
"CD_DEP_COUNT",
clv_xgboost_udf("C_BIRTH_YEAR", "CA_ZIP", "CD_GENDER", "CD_MARITAL_STATUS", "CD_CREDIT_RATING", "CD_EDUCATION_STATUS", "CD_DEP_COUNT") AS "PREDICTION",
"TOTAL_SALES" AS "ACTUAL_SALES"
FROM tpcds_xgboost.demo.feature_store

Now, about performance: the entire inference pipeline (the feature engineering plus inference code above) runs against the 100 TB TPC-DS dataset in about 5.6 minutes on a 3XL warehouse. That’s roughly $17 at Snowflake Enterprise pricing. Plus, it’s all done with no operational overhead: no VMs to maintain, no cloud infrastructure to provision, no autoscaling, patching, security group configuration, or IAM policies to worry about. It all runs in a secure, managed environment for you.
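Scaling to that level is a configuration change, not a code change. As a sketch, you could resize the warehouse just for the big batch run and size it back down afterwards (the warehouse name and sizes here are illustrative):

# Scale up for the large batch inference job ...
session.sql("alter warehouse my_wh set warehouse_size = 'XXXLARGE'").collect()

# ... run the feature engineering and inference code from above ...

# ... then scale back down (or rely on auto-suspend) when the job is done
session.sql("alter warehouse my_wh set warehouse_size = 'XSMALL'").collect()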

As I tell all of the customers I work with, please do not take my word for it. Try it yourself; all of the code is available for you here. Better yet, skip TPC-DS and take it for a test drive with some of your own organization’s data and pipelines.
