Easy storage of custom Spark ML models

Max Baak
May 15, 2024

ING bank’s SparkWritable/SparkReadable classes are now open-source.
Go try out our solution!

TL;DR

At ING Wholesale Banking Advanced Analytics we have open-sourced Python code to simplify the storage and loading of custom Spark ML models. (De-)Serialisation of a custom Spark ML model is now as simple as inheriting from two extra classes, and indicating which class attributes need to be persisted and loaded from disk.

The problem we solve

Anyone who works on a Spark cluster and applies ML models to Spark dataframes will likely know the problem: persisting custom Spark ML models is non-trivial. Spark supports storing the classes of its machine learning library (MLlib), dataframes and model parameters. Storing custom Spark objects, however, turns out to be complex, for example when an object contains custom attributes such as model settings or sklearn objects.

In more detail, the Python bindings for Spark are not written as idiomatic Python. Instead, they use Java-style design patterns with Python classes. This, combined with missing documentation, makes it hard to define custom Spark ML serialization in a straightforward way. We are not aware of any simple, generic solution for this.
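Here is what "Java-style design patterns" means in practice. In pyspark, `MLWritable.save(path)` delegates to `self.write().save(path)`, and `MLReadable.load(path)` delegates to `cls.read().load(path)`. Custom persistence therefore means supplying your own writer and reader objects. The plain-Python sketch below mimics that indirection. All class names and the in-memory `FAKE_FS` dict are hypothetical stand-ins, not Spark or emm code.

```python
from typing import Any, Dict

# Hypothetical in-memory "filesystem" so the sketch runs without Spark.
FAKE_FS: Dict[str, dict] = {}


class WriterSketch:
    """Writer object returned by write(), in the style of pyspark's MLWriter."""

    def __init__(self, instance: Any) -> None:
        self.instance = instance

    def save(self, path: str) -> None:
        # Store the instance attributes under the given path.
        FAKE_FS[path] = dict(vars(self.instance))


class ReaderSketch:
    """Reader object returned by read(), in the style of pyspark's MLReader."""

    def __init__(self, cls: type) -> None:
        self.cls = cls

    def load(self, path: str) -> Any:
        # Rebuild the object by passing the stored attributes to __init__.
        return self.cls(**FAKE_FS[path])


class WritableSketch:
    def write(self) -> WriterSketch:
        return WriterSketch(self)

    def save(self, path: str) -> None:
        # Same shape as pyspark's MLWritable.save: delegate to a writer object.
        self.write().save(path)


class ReadableSketch:
    @classmethod
    def read(cls) -> ReaderSketch:
        return ReaderSketch(cls)

    @classmethod
    def load(cls, path: str) -> Any:
        # Same shape as pyspark's MLReadable.load: delegate to a reader object.
        return cls.read().load(path)


class Model(WritableSketch, ReadableSketch):
    def __init__(self, a: str = "score") -> None:
        self.a = a


Model(a="foo").save("mymodel")
print(Model.load("mymodel").a)  # foo
```

Each extra layer (writer, reader, their factories) is a place where a custom attribute can silently be left out. That is why persisting non-Param attributes is not automatic.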

As a work-around, one may resort to storing all non-Spark objects separately and reinitialising the Spark object manually in each new session. But this is cumbersome for nested constructs of objects and clearly not ideal.

At ING Wholesale Banking Advanced Analytics we have developed Python code to make the storage and reading of custom ML models much easier. Two classes, SparkWritable and SparkReadable, solve the problem of storing and loading custom Spark ML models of arbitrary complexity.

These classes are part of ING's Entity Matching Model (emm) Python package. Let us show you how to use them.

Example use

As a quick example, one can do:

pip install -U emm[spark]
from typing import Optional

from pyspark.ml import Transformer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import DataFrame
from sklearn.pipeline import FeatureUnion

from emm.helper.spark_custom_reader_writer import SparkReadable, SparkWriteable


# Example custom Spark transformer class with non-Spark attributes.
# The class needs to inherit from SparkReadable, SparkWriteable, DefaultParamsReadable and DefaultParamsWritable.
class MySparkTransformer(
    Transformer, SparkReadable, SparkWriteable, DefaultParamsReadable, DefaultParamsWritable
):
    """My Spark Transformer"""

    # These are the attributes that are serialized. Here "a", "b" and "c".
    SERIALIZE_ATTRIBUTES = (
        "a",
        "b",
        "c",
    )

    def __init__(
        self,
        a: str = "score",
        b: Optional[list] = None,
        c: Optional[FeatureUnion] = None,
    ) -> None:
        super().__init__()
        # __init__() must accept the "a", "b" and "c" arguments, so that load() can rebuild the object.
        self.a = a
        self.b = b
        self.c = c

    def _transform(self, dataframe: DataFrame) -> DataFrame:
        # Use self.a, self.b and self.c here somewhere!
        return dataframe

To store and read back, simply do:

# storage
model1 = MySparkTransformer(a="foo", b=[0,1,2,3,4,5,6,7])
model1.save("mymodel")

# loading from disk
model2 = MySparkTransformer.load("mymodel")
print(model2.a, model2.b, model2.c)

The example stores the custom model in the local directory mymodel/.

That's all there is to it!

All you need to know (to use it)

Only a couple of things!

  • Your class needs to inherit from the classes SparkReadable, SparkWritable, DefaultParamsReadable and DefaultParamsWritable. The last two, commonly used for Spark parameters, add the load() and save() functions to your class.
  • All class attributes that need to be serialised when calling save() must be listed in SERIALIZE_ATTRIBUTES. If your class also needs the current Spark session, specify the name of that attribute in SPARK_SESSION_KW.
  • Each of these attributes must be accepted as an argument by the class's __init__(attrA, attrB, ..., spark_session_kw) method, which sets self.attrA = attrA, self.spark_session_kw = spark_session_kw, etc. This is needed to call load() successfully.
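To make these three rules concrete, here is a minimal pure-Python sketch of the idea, without Spark. It is not emm's implementation. The mixin name `AttributeSaverMixin`, the `attributes.pkl` file, the use of `pickle`, and passing the session explicitly to `load()` are all assumptions for illustration. The sketch shows why every name in SERIALIZE_ATTRIBUTES must be an `__init__` argument: loading rebuilds the object as `cls(**attributes)`. It also shows that the session is re-injected under the keyword named by SPARK_SESSION_KW instead of being stored.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional


class AttributeSaverMixin:
    """Hypothetical stand-in for SparkWritable/SparkReadable (illustration only)."""

    SERIALIZE_ATTRIBUTES: tuple = ()
    SPARK_SESSION_KW: Optional[str] = None

    def save(self, path: str) -> None:
        # Persist only the attributes listed in SERIALIZE_ATTRIBUTES.
        directory = Path(path)
        directory.mkdir(parents=True, exist_ok=True)
        attributes = {name: getattr(self, name) for name in self.SERIALIZE_ATTRIBUTES}
        with open(directory / "attributes.pkl", "wb") as f:
            pickle.dump(attributes, f)

    @classmethod
    def load(cls, path: str, spark_session: Any = None):
        with open(Path(path) / "attributes.pkl", "rb") as f:
            attributes = pickle.load(f)
        # The session is never stored; it is passed in again at load time.
        if cls.SPARK_SESSION_KW is not None:
            attributes[cls.SPARK_SESSION_KW] = spark_session
        # This call is why __init__ must accept every serialized attribute.
        return cls(**attributes)


class MyModel(AttributeSaverMixin):
    SERIALIZE_ATTRIBUTES = ("a", "b")
    SPARK_SESSION_KW = "spark"

    def __init__(self, a: str = "score", b: Optional[list] = None, spark: Any = None) -> None:
        self.a = a
        self.b = b
        self.spark = spark


with tempfile.TemporaryDirectory() as tmp:
    model1 = MyModel(a="foo", b=[0, 1, 2])
    model1.save(f"{tmp}/mymodel")
    model2 = MyModel.load(f"{tmp}/mymodel", spark_session="<session>")

print(model2.a, model2.b, model2.spark)  # foo [0, 1, 2] <session>
```

If `__init__` lacked, say, the `b` argument, the `cls(**attributes)` call would raise a `TypeError`. This is the failure the third rule above guards against.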

Under the hood

All underlying magic is contained in the SparkWritable and SparkReadable classes. Three attribute types are recognised for storage and reading; together these cover most objects in practice.

  • Does the object inherit from DefaultParamsWritable or MLWritable, as Spark ML classes do? If so, that object's own save() and load() functions are used.
  • Is the object a Spark dataframe? If so, it is stored and read back with Spark's own dataframe writer and reader (df.write and spark.read).
  • Any other object that can be dumped to a file and loaded from it is stored using (by default) the joblib library, in joblib's compressed binary format.
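The three-way decision above can be sketched as a small dispatch function. This is a hypothetical illustration, not emm's code. To keep it runnable without Spark, it inspects class names in the MRO instead of importing pyspark. The returned labels are made up for the example.

```python
from typing import Any


def storage_kind(obj: Any) -> str:
    """Return which of the three storage routes an attribute would take (sketch)."""
    cls = type(obj)
    mro_names = {base.__name__ for base in cls.__mro__}
    # 1. Spark ML objects: delegate to their own save()/load().
    if mro_names & {"DefaultParamsWritable", "MLWritable"}:
        return "spark_ml"
    # 2. Spark dataframes: use Spark's dataframe writer and reader.
    if cls.__name__ == "DataFrame" and cls.__module__.startswith("pyspark."):
        return "spark_dataframe"
    # 3. Everything else: generic file dump (joblib by default, per the text above).
    return "generic_dump"


class MLWritable:  # local stand-in that only shares its name with pyspark.ml.util.MLWritable
    pass


class MyStage(MLWritable):
    pass


print(storage_kind(MyStage()))           # spark_ml
print(storage_kind({"threshold": 0.5}))  # generic_dump
```

In real code, `isinstance` checks against the actual pyspark classes would be more robust than name matching. The name-based check here only avoids a hard pyspark dependency in the sketch.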

We stress that this also allows storing arbitrary constructs of Spark objects, such as Pipelines containing instances of custom Spark classes.
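Nested constructs work when saving is recursive. An attribute that is itself a saveable object writes itself into its own subdirectory, and the parent only records the class name. The sketch below illustrates that idea in plain Python with JSON. The `REGISTRY` dict, the `class.json`/`list.json` marker files and the `MyPipeline`/`MyScaler` classes are hypothetical and much simpler than Spark's own Pipeline persistence.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict

# Hypothetical registry mapping class names to classes, used when loading.
REGISTRY: Dict[str, type] = {}


def register(cls: type) -> type:
    REGISTRY[cls.__name__] = cls
    return cls


def save_obj(obj: Any, path: Path) -> None:
    """Recursively save obj; objects with SERIALIZE_ATTRIBUTES become directories."""
    if hasattr(obj, "SERIALIZE_ATTRIBUTES"):
        path.mkdir(parents=True, exist_ok=True)
        (path / "class.json").write_text(json.dumps(type(obj).__name__))
        for name in obj.SERIALIZE_ATTRIBUTES:
            save_obj(getattr(obj, name), path / name)
    elif isinstance(obj, list):
        # Lists (e.g. pipeline stages) become numbered subdirectories.
        path.mkdir(parents=True, exist_ok=True)
        (path / "list.json").write_text(json.dumps(len(obj)))
        for i, item in enumerate(obj):
            save_obj(item, path / str(i))
    else:
        # Leaf value: a plain JSON file (the article uses joblib for this case).
        path.with_suffix(".json").write_text(json.dumps(obj))


def load_obj(path: Path) -> Any:
    """Inverse of save_obj: rebuild objects via their __init__ arguments."""
    if (path / "class.json").exists():
        cls = REGISTRY[json.loads((path / "class.json").read_text())]
        kwargs = {name: load_obj(path / name) for name in cls.SERIALIZE_ATTRIBUTES}
        return cls(**kwargs)
    if (path / "list.json").exists():
        n = json.loads((path / "list.json").read_text())
        return [load_obj(path / str(i)) for i in range(n)]
    return json.loads(path.with_suffix(".json").read_text())


@register
class MyScaler:
    SERIALIZE_ATTRIBUTES = ("factor",)

    def __init__(self, factor: float = 1.0) -> None:
        self.factor = factor


@register
class MyPipeline:
    SERIALIZE_ATTRIBUTES = ("stages",)

    def __init__(self, stages: list = None) -> None:
        self.stages = stages or []


with tempfile.TemporaryDirectory() as tmp:
    pipe = MyPipeline(stages=[MyScaler(2.0), MyScaler(0.5)])
    save_obj(pipe, Path(tmp) / "pipeline")
    restored = load_obj(Path(tmp) / "pipeline")

print([stage.factor for stage in restored.stages])  # [2.0, 0.5]
```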

Storage to S3

Writing to and reading from S3 requires dedicated access functions for all non-Spark attributes. Examples of S3 write_json() and read_json() functions are given in the Appendix below. Note that these may be cluster-specific.

Given such read and write functions, one can configure these for all custom Spark objects in one go with the singleton IOFunc class:

from emm.helper.io import IOFunc

IOFunc().writer = write_json
IOFunc().reader = read_json

They are then picked up by all SparkWritable and SparkReadable objects.

Conclusion

At ING Wholesale Banking Advanced Analytics we have open-sourced Python code to simplify the storage and loading of custom Spark ML models. This is (currently) contained in the Entity Matching Model library emm, and can be easily picked up (or copied) by anyone from there. We invite you to try it out and are happy to hear your feedback!

Contributors

This code was authored by ING Analytics Wholesale Banking, in particular: Max Baak, Simon Brugman and Tomasz Waleń. Thanks to Nikoletta Bozika for reviewing this blog.

Appendix

Example read_json() and write_json() functions that support both local files and S3.

from __future__ import annotations  # allows "str | Path" annotations on Python < 3.10

import json
import os
from pathlib import Path
from typing import Any

import boto3


def read_json(fn: str | Path) -> Any:
    """Reading JSON file (both local & S3 files supported)"""
    if str(fn).startswith("s3:"):
        return json.loads(read_s3_file(str(fn)))
    # handling regular files
    with open(fn) as f:
        return json.load(f)


def write_json(fn: str | Path, data: Any, **kwargs) -> None:
    """Writing JSON file (both local & S3 files supported)"""
    if str(fn).startswith("s3:"):
        write_s3_file(str(fn), json.dumps(data, **kwargs))
    else:
        # handling regular files
        with open(fn, "w") as f:
            json.dump(data, f, **kwargs)


def read_s3_file(fn: str, client=None) -> bytes:
    bucket, key = parse_s3_path(fn)
    s3_client = init_s3_client() if client is None else client
    response = s3_client.get_object(Bucket=bucket, Key=key)
    return response["Body"].read()


def write_s3_file(fn: str, contents: str) -> None:
    bucket, key = parse_s3_path(fn)
    s3_client = init_s3_client()
    s3_client.put_object(Bucket=bucket, Body=contents, Key=key)


def parse_s3_path(fn: str) -> tuple[str, str]:
    # "s3://bucket/path/to/key" -> ("bucket", "path/to/key").
    # Pass S3 URIs as str: pathlib.Path would collapse "s3://" to "s3:/".
    tmp = fn.split("://")[1].split("/")
    return tmp[0], "/".join(tmp[1:])


def init_s3_client():
    # Endpoint and credentials are taken from environment variables.
    AWS_S3_ENDPOINT_URL = os.getenv("AWS_S3_ENDPOINT_URL")
    AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
    AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
    AWS_SESSION_TOKEN = os.getenv("AWS_SESSION_TOKEN")
    return boto3.client(
        "s3",
        endpoint_url=AWS_S3_ENDPOINT_URL,
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
        aws_session_token=AWS_SESSION_TOKEN,
    )
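parse_s3_path() in the code above assumes its input is a well-formed S3 URI. On anything else, such as a local path, it fails with an IndexError. A slightly more defensive variant is sketched below. The name parse_s3_uri and its ValueError behaviour are our own illustrative choices. It uses str.partition rather than urllib.parse.urlparse, because urlparse treats "?" and "#" as query and fragment delimiters, and those characters are allowed in S3 keys.

```python
def parse_s3_uri(uri: str) -> tuple:
    """Split 's3://bucket/some/key' into ('bucket', 'some/key'); reject anything else."""
    scheme, sep, rest = str(uri).partition("://")
    if scheme != "s3" or not sep:
        raise ValueError(f"not an S3 URI: {uri!r}")
    bucket, _, key = rest.partition("/")
    if not bucket:
        raise ValueError(f"missing bucket in S3 URI: {uri!r}")
    # The key keeps any "?" or "#" characters verbatim.
    return bucket, key


print(parse_s3_uri("s3://my-bucket/models/mymodel/a.json"))
# ('my-bucket', 'models/mymodel/a.json')
```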

Data scientist, Researcher, former Particle Physicist