Performing distributed predictions at scale with Snowpark on PySpark models using ONNX

Problem Statement

Snowpark for Python, a price-performant open-source Python library (pip | conda) built by Snowflake, lets Snowflake customers build architecturally simple data engineering and data science pipelines on Snowflake’s managed cloud. It removes the need to manage, optimize, or troubleshoot clusters when processing large volumes of data, so customers can focus on the pipelines that drive business insights.

Several of our customers use PySpark for model training but keep their data in Snowflake, both to maintain a single source of truth for data governance and to share data with their own customers at scale. Specifically, one customer asked how a model trained with SynapseML LightGBMClassifier on PySpark can be brought into Snowflake using Snowpark Python to perform distributed batch predictions while keeping the model close to the data.

In a nutshell, the answer was to export the PySpark model to the Open Neural Network Exchange (ONNX) format and bring it into Snowflake as a Snowpark Vectorized UDF. This UDF uses the onnxruntime library, which is already available in Snowflake along with 1000+ other packages through our native Snowpark Anaconda integration. It loads the ONNX model into cache only once using cachetools, which is also available in Snowflake, and then performs distributed batch predictions. This is shown in Figure 1.

Figure 1: End to End architecture

The rest of this blog walks through the end-to-end process used to achieve this.

End to End implementation

This section is split into five steps. To keep this post brief, it starts at Step1: Push-down preprocessing/feature engineering with Snowpark; if you are interested in how the CSV data was loaded into Snowflake using Snowpark for Python, see here. The example uses a public sample dataset on strokes, in which several predictors are used to predict whether a patient is likely to have a stroke. The CSV data can be seen here. The full source code for this example is available on GitHub here.

Step0: Prerequisites

The libraries shown in Figure 2 were installed on a Databricks cluster with runtime 10.4 and Python 3.8. Apart from Step2: Model Training with PySpark, every step needs only a single Databricks driver and worker node with minimal configuration, because Snowpark pushes all of the work down into Snowflake. The Databricks cluster therefore only needs to be sized for training the PySpark model in Step2: Model Training with PySpark below.

Figure 2: Libraries installed on PySpark cluster

Step1: Push-down preprocessing/feature engineering with Snowpark

The CSV data loaded into Snowflake contains the columns ID, GENDER, AGE, HYPERTENSION, HEART_DISEASE, EVER_MARRIED, WORK_TYPE, RESIDENCE_TYPE, AVG_GLUCOSE_LEVEL, BMI, SMOKING_STATUS and STROKE. A preview of the data in the Snowflake table is shown below in Figure 3.

Figure 3: Preview of the raw data in Snowflake table

As seen in Figure 3, the columns GENDER, EVER_MARRIED, WORK_TYPE, RESIDENCE_TYPE and SMOKING_STATUS contain categorical data. We will label-encode these columns with Snowpark Python, pushing the work down into Snowflake. Additionally, we will handle missing data and apply SMOTE to address the heavy class imbalance in the target variable STROKE.

Step1(a): Push-down label encoding of categorical data

To label-encode the categorical variables GENDER, EVER_MARRIED, WORK_TYPE, RESIDENCE_TYPE and SMOKING_STATUS, we used a few utilities from the “preprocessing” folder available here, which were created by a colleague on my team. The label-encoding process and its output are shown in Figure 4. The entire notebook is available here on GitHub.

Figure 4: Label-encoding categorical variables with Snowpark Python
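
As a rough illustration of what label encoding does to a categorical column, the plain-Python sketch below maps each distinct category to an integer code. It is not the preprocessing utility used in the notebook, which applies the encoding inside Snowflake; the function names `fit_label_encoder` and `transform_labels` and the sample values are hypothetical, and assigning codes in sorted order is an assumption.

```python
from typing import Dict, Iterable, List


def fit_label_encoder(values: Iterable[str]) -> Dict[str, int]:
    """Map each distinct category to an integer code, in sorted order."""
    return {category: index for index, category in enumerate(sorted(set(values)))}


def transform_labels(values: Iterable[str], mapping: Dict[str, int]) -> List[int]:
    """Replace every category with its integer code."""
    return [mapping[value] for value in values]


# Hypothetical sample of the SMOKING_STATUS column.
smoking_status = ["never smoked", "smokes", "formerly smoked", "smokes"]
mapping = fit_label_encoder(smoking_status)
encoded = transform_labels(smoking_status, mapping)
```

The mapping learned on the training data has to be reused unchanged at prediction time, so that the same category always receives the same code.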

Step1(b): Missing data management and data filtering

For this modeling exercise, only patients above the age of 2 were considered, and rows with missing data were dropped. This process is shown in Figure 5. The entire notebook is available here on GitHub.

Figure 5: Missing data management and data filtering with Snowpark Python
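
The filtering logic itself is small. The sketch below expresses it in plain Python on a few hypothetical rows, purely to make the two rules explicit; the function name `filter_and_drop_missing` is an assumption. In Snowpark the same intent is typically expressed with `DataFrame.filter` and `DataFrame.dropna`, which are pushed down to Snowflake rather than executed row by row on the client.

```python
from typing import Dict, List, Optional

Row = Dict[str, Optional[float]]


def filter_and_drop_missing(rows: List[Row], min_age: float = 2) -> List[Row]:
    """Keep rows with AGE above min_age and no missing (None) values."""
    kept = []
    for row in rows:
        if any(value is None for value in row.values()):
            continue  # drop rows with missing data
        if row["AGE"] > min_age:
            kept.append(row)
    return kept


# Hypothetical rows, not taken from the actual dataset.
rows = [
    {"AGE": 1.5, "BMI": 18.0, "STROKE": 0},   # dropped: age not above 2
    {"AGE": 45.0, "BMI": None, "STROKE": 0},  # dropped: missing BMI
    {"AGE": 67.0, "BMI": 36.6, "STROKE": 1},  # kept
]
clean = filter_and_drop_missing(rows)
```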

Step1(c): Synthetic Minority Oversampling Technique (SMOTE)

Of the resulting 29,072 records, 28,524 had no stroke and only 548 had a stroke. Because the dataset was highly imbalanced, SMOTE was applied to it in a Snowpark Stored Procedure pushed down into Snowflake. For very large datasets that must be brought into memory, Snowpark Optimized Warehouses (in public preview) can optionally be used before executing the SMOTE stored procedure. The stored procedure that performs SMOTE on our dataset is shown in Figure 6. More details about the imblearn package are available here. The entire notebook is available here on GitHub.

Figure 6: SMOTE with Snowpark Python Stored Procedure
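
To give a feel for what SMOTE does, here is a minimal pure-Python sketch of its core step: each synthetic minority sample is placed at a random position on the line segment between an existing minority sample and one of its nearest minority neighbours. This is an illustration only, not the imblearn implementation used in the stored procedure; the function name `smote_oversample`, the value `k=2`, the seed, and the sample rows are assumptions.

```python
import random
from typing import List, Sequence

Point = Sequence[float]


def squared_distance(a: Point, b: Point) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def smote_oversample(
    minority: List[Point], n_new: int, k: int = 2, seed: int = 0
) -> List[List[float]]:
    """Create n_new synthetic points by interpolating between a minority
    sample and one of its k nearest minority neighbours."""
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        base = rng.choice(minority)
        # k nearest neighbours of `base` among the other minority samples
        neighbours = sorted(
            (p for p in minority if p is not base),
            key=lambda p: squared_distance(base, p),
        )[:k]
        neighbour = rng.choice(neighbours)
        gap = rng.random()  # position along the segment base -> neighbour
        synthetic.append([b + gap * (n - b) for b, n in zip(base, neighbour)])
    return synthetic


# Hypothetical minority-class rows: (AGE, AVG_GLUCOSE_LEVEL)
minority = [(67.0, 228.7), (80.0, 105.9), (74.0, 70.1), (61.0, 202.2)]
new_points = smote_oversample(minority, n_new=6)
```

Because every synthetic point lies between two real minority samples, the new rows stay inside the region already occupied by the minority class instead of being exact duplicates.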

Once the data is resampled within the Snowpark Python Stored Procedure, we end up with 28,524 records with no stroke and 28,524 records with a stroke. This resampled data can then be split 80:20 into training and testing sets and saved into Snowflake as shown in Figure 7.

Figure 7: Re-balanced dataset saved as training and testing tables within Snowflake
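
The arithmetic behind the split can be checked with a short sketch. It assumes an exact 80:20 cut over shuffled row indices with an arbitrary seed; the notebook may instead use a random split whose sizes are only approximately 80:20.

```python
import random

n_no_stroke = 28_524
n_stroke = 28_524  # after SMOTE both classes have the same size
n_total = n_no_stroke + n_stroke

# Shuffle row indices with a fixed (arbitrary) seed, then cut at the 80% mark.
indices = list(range(n_total))
random.Random(42).shuffle(indices)
cut = n_total * 80 // 100
train_idx, test_idx = indices[:cut], indices[cut:]

print(n_total, len(train_idx), len(test_idx))  # prints: 57048 45638 11410
```

The test set of roughly 11.4K rows is consistent with the batch size scored in Step5.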

Step2: Model Training with PySpark

To train the model in PySpark, we need a PySpark dataframe on top of the feature-engineered data that Snowpark Python saved into a Snowflake table. For this, we can use the built-in Snowflake Spark connector within Databricks. The columns with data type decimal(38,0) in PySpark were also converted to integers before model training. This is shown in Figure 8.

Figure 8: Feature Engineered dataset retrieval in PySpark for Model Training

Once the data is available as a PySpark dataframe, it is run through a VectorAssembler transformation, and Microsoft’s SynapseML library is then used to train a LightGBMClassifier model that predicts whether a stroke will occur. This is shown in Figure 9. The entire notebook is available here on GitHub.

Figure 9: LightGBMClassifier model trained in PySpark

Step3: Export to ONNX model

ONNX is an open format for ML models that allows models to be exchanged between ML frameworks and tools. Our trained LightGBMClassifier model is converted into ONNX format using the ONNXMLTools Python package, as shown in Figure 10. The entire notebook is available here on GitHub.

Figure 10: Exporting the model into ONNX format

Step4: Snowpark Vectorized UDF that brings in ONNX model

With the ONNX model on the Databricks driver, we can upload it to a Snowflake stage and create a vectorized UDF using the Snowpark Python UDF Batch API. The custom import into the vectorized UDF is made possible by the “imports” option used when the UDF is registered with Snowpark. This option takes the ONNX model from the Databricks driver and places it in a special directory, ‘snowflake_import_directory’, within the secure Python sandbox in Snowflake. The Snowpark Vectorized UDF loads the model from ‘snowflake_import_directory’ once using the onnxruntime library and keeps it in cache by means of the cachetools library. All the libraries needed for this Snowpark Vectorized UDF are available in the Snowflake conda channel, as listed here. Registering the Snowpark Vectorized UDF and loading the model into cache with onnxruntime is shown in Figure 11. The entire notebook is available here on GitHub.

Figure 11: Snowpark Vectorized UDF that loads the ONNX model into memory once and performs distributed batch predictions
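
The load-once pattern is the part of this step that is easiest to get wrong, so here is a minimal, locally runnable sketch of it. It swaps in the standard library's `functools.lru_cache` for cachetools and a stub for the onnxruntime session, so it shows the caching idea rather than the UDF from the notebook. The names `load_model`, `score_batch`, `load_count`, the file name `model.onnx`, and the threshold "model" are all hypothetical.

```python
import functools
import os
import sys
from typing import Callable, List

# Inside a Snowflake Python UDF, files passed via `imports` are placed in the
# directory given by this interpreter option; locally we fall back to ".".
IMPORT_DIR = sys._xoptions.get("snowflake_import_directory", ".")

load_count = 0  # only here to demonstrate that loading happens once


@functools.lru_cache(maxsize=1)
def load_model(path: str) -> Callable[[List[float]], int]:
    """Stand-in for creating an onnxruntime inference session from `path`."""
    global load_count
    load_count += 1
    # Hypothetical "model": predicts 1 when the first feature exceeds 60.
    return lambda features: int(features[0] > 60)


def score_batch(batch: List[List[float]]) -> List[int]:
    """Called once per batch; the cached model is reused across calls."""
    model = load_model(os.path.join(IMPORT_DIR, "model.onnx"))
    return [model(row) for row in batch]


# Three batches are scored, but the model is loaded only for the first one.
predictions = [score_batch(b) for b in ([[67.0], [30.0]], [[80.0]], [[45.0]])]
```

Without the cache, the model file would be read and deserialized again for every batch the vectorized UDF receives, which would dominate the runtime for small batches.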

Once the Snowpark Vectorized UDF is registered, this auto-creates the user function in Snowflake, along with the ONNX model imported as a dependency as shown in Figure 12.

Figure 12: Snowpark Vectorized UDF available as a user function in Snowflake with ONNX model as a dependency

Step5: Native ONNX predictions within Snowflake

With the Snowpark Vectorized UDF available as a user function within Snowflake, we can simply call it, passing in the feature columns, and save the results in a Snowflake table. The Snowpark Vectorized UDF will bring in the imported ONNX model and perform distributed batch predictions right within Snowflake. This is shown in Figures 13 and 14. The entire notebook is available here on GitHub.

Figure 13: Calling Snowpark Vectorized UDF to make distributed batch predictions in Snowflake

Figure 14: Saving Snowpark dataframe containing predictions as a Snowflake table

In this example, performing batch predictions on 11.4K rows took 2.58 seconds on a small Snowflake warehouse. This invocation of the ONNX model within Snowflake is shown in Figure 15.

Figure 15: Invocation of SCORE_ONNX_MODEL for predicting natively within Snowflake

We can also define functions to calculate classification metrics such as accuracy, F1 score, precision and recall, and to build the confusion matrix, on top of the scored Snowpark dataframe, which contains the actual STROKE column and the PREDICTED_STROKE column.

The computed classification metrics and confusion matrix are shown in Figure 16.

Figure 16: Classification metrics and confusion matrix
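
For reference, the metrics above all follow from the four cells of the binary confusion matrix. The sketch below computes them in plain Python from two lists of labels; it is not the notebook's implementation, which works on the scored Snowpark dataframe, and the function name and sample labels are hypothetical.

```python
from typing import Dict, Sequence


def binary_classification_metrics(
    actual: Sequence[int], predicted: Sequence[int]
) -> Dict[str, float]:
    """Confusion-matrix cells and derived metrics for 0/1 labels."""
    pairs = list(zip(actual, predicted))
    tp = sum(1 for a, p in pairs if a == 1 and p == 1)
    tn = sum(1 for a, p in pairs if a == 0 and p == 0)
    fp = sum(1 for a, p in pairs if a == 0 and p == 1)
    fn = sum(1 for a, p in pairs if a == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    return {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1,
    }


# Hypothetical STROKE and PREDICTED_STROKE values.
actual = [1, 1, 1, 0, 0, 0, 0, 1]
predicted = [1, 0, 1, 0, 1, 0, 0, 1]
metrics = binary_classification_metrics(actual, predicted)
```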

Conclusion

This blog shows the flexibility of the open-source Snowpark for Python library (pip | conda). It can be installed anywhere a Python 3.8 kernel is available (support for 3.9+ is coming in the near future), it can push preprocessing and feature engineering down into Snowflake, and it can bring in models trained in PySpark and converted to ONNX as vectorized UDFs to perform batch predictions at scale. With this architectural pattern, customers who have invested in PySpark-based machine learning models can easily use Snowpark at scale for their data science workloads. For more information on Snowpark for Python, refer to the Snowpark Developer Guide for Python, or try out one of our QuickStarts.
