Uncovering MLFlow’s Spark UDF

How it works under the hood for isolated conda environments, and what the caveats are

Yerachmiel Feltzman
Israeli Tech Radar
Apr 24, 2023


Photo by Kevin Bidwell on Pexels

In our previous article, we discussed a solution for avoiding the ML dependency-syncing black hole. Instead of keeping the dependencies used during model training in sync with those used during model serving, we run the model and the inference service together but completely isolated from each other.

The end of the conflicts. Peace on earth.

Photo by mali maeder on Pexels

In practical terms, it meant three things:

  1. Models will be registered in the MLFlow Model Registry with their training environment defined in a conda file;
  2. The inference will be made using Apache Spark;
  3. The prediction in Spark will be executed using MLFlow’s Spark UDF with an isolated conda environment.
ML inference service with Spark and MLFlow on environment isolation mode — image by the author

How does the magic happen? What are the caveats?

Note: This article jumps straight into a deep dive of the UDF implementation. If you haven’t read the previous article yet, I highly suggest doing so. It explains the problem we are trying to solve and how we solve it using the steps above: “Avoid the ML dependencies syncing black-hole”.

Let’s uncover MLFlow’s Spark UDF!

How does MLFlow’s Spark UDF with conda environment work under the hood?

Let’s analyze the signature of mlflow.pyfunc.spark_udf:

def spark_udf(spark,
              model_uri,
              result_type="double",
              env_manager=_EnvManager.LOCAL): ...

spark and result_type are straightforward, and so is model_uri. Let’s focus on env_manager, which is what allows us to avoid the ML model’s dependency-syncing challenge:

Having the model and the inference service running together but completely isolated would mean we don’t need to solve the training and inference dependencies syncing challenge because the ML model can have whatever dependency it needs and the inference service can have whatever dependency it needs.

That’s what we stated in our previous article.

What happens when we run mlflow.pyfunc.spark_udf with env_manager="conda"?

import mlflow
from pyspark.sql.functions import struct
predict_udf = mlflow.pyfunc.spark_udf(..., env_manager="conda")  # spark session and model URI elided
predictions = input_data.withColumn("prediction", predict_udf(struct("value")))

  1. The model will be loaded from the registry;
  2. It will be served as a Pandas UDF;
  3. A localhost scoring server, holding the model loaded from the registry, will be spun up in a child process using an isolated conda environment;
  4. When called for inference, MLFlow’s Spark UDF will internally request the localhost scoring server and return its response.
Uncovering the UDF — image by the author
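To make the four steps concrete, here is a minimal, stdlib-only sketch of the same pattern. It is not MLflow’s code: the “model” is a hypothetical function that doubles its inputs, the child process reuses the current Python interpreter instead of a conda environment, the {"inputs": ...}/{"predictions": ...} payload shape is a simplification, and SERVER_SCRIPT, wait_until_listening and predict_batches are illustrative names. What it does show is the architecture: one localhost server per UDF invocation, started in a child process and then queried over HTTP for every batch.

```python
import json
import socket
import subprocess
import sys
import time
import urllib.request

# Child-process "scoring server". The model is a hypothetical stand-in that
# doubles every input value; MLflow would load the real model here instead.
SERVER_SCRIPT = r'''
import json, sys
from http.server import BaseHTTPRequestHandler, HTTPServer

class Handler(BaseHTTPRequestHandler):
    def do_POST(self):  # answers POSTs on any path, e.g. /invocations
        body = self.rfile.read(int(self.headers["Content-Length"]))
        inputs = json.loads(body)["inputs"]
        payload = json.dumps({"predictions": [2 * x for x in inputs]}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, *args):  # silence per-request logging
        pass

HTTPServer(("127.0.0.1", int(sys.argv[1])), Handler).serve_forever()
'''


def find_free_port() -> int:
    # Binding to port 0 asks the OS for any unused port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def wait_until_listening(port: int, timeout: float = 10.0) -> None:
    # Poll until the child process accepts TCP connections.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection(("127.0.0.1", port), timeout=0.5):
                return
        except OSError:
            time.sleep(0.05)
    raise TimeoutError(f"scoring server on port {port} did not start")


def predict_batches(batches):
    """Iterator-style "UDF": start one server, then score every batch through it."""
    port = find_free_port()
    # MLflow would start this child inside the model's conda env;
    # the sketch simply reuses the current interpreter.
    proc = subprocess.Popen([sys.executable, "-c", SERVER_SCRIPT, str(port)])
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))  # never proxy localhost
    try:
        wait_until_listening(port)
        for batch in batches:
            request = urllib.request.Request(
                f"http://127.0.0.1:{port}/invocations",
                data=json.dumps({"inputs": batch}).encode(),  # setting data makes this a POST
                headers={"Content-Type": "application/json"},
            )
            with opener.open(request) as response:
                yield json.loads(response.read())["predictions"]
    finally:
        proc.terminate()
        proc.wait()
```

For example, list(predict_batches([[1, 2], [3]])) returns [[2, 4], [6]]. Note the small window between find_free_port releasing the port and the child binding it, during which another process could grab it; that is acceptable for a sketch.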

Diving into MLFlow’s source code

Let’s explore MLFlow’s source code (mlflow==2.2.2) for each of the steps above:

1. The model will be loaded from the registry

def spark_udf(spark, model_uri, result_type="double", env_manager=_EnvManager.LOCAL):
    ...
    local_model_path = _download_artifact_from_uri(
        artifact_uri=model_uri,
        output_path=_create_model_downloading_tmp_dir(should_use_nfs),
    )

    ...

    model_metadata = Model.load(os.path.join(local_model_path, MLMODEL_FILE_NAME))
    ...

2. It will be served as a Pandas UDF

def spark_udf(spark, model_uri, result_type="double", env_manager=_EnvManager.LOCAL):
    # A LOT OF PREPARATION STEPS
    ...
    @pandas_udf(result_type)
    def udf(
        iterator: Iterator[Tuple[Union[pandas.Series, pandas.DataFrame], ...]]
    ) -> Iterator[result_type_hint]: ...

    @functools.wraps(udf)
    def udf_with_default_cols(*args):
        ...
        else:
            return udf(*args)

    return udf_with_default_cols

Note that what MLFlow's spark_udf actually returns is a pandas_udf, built after a lot of preparation work (as we will see) and wrapped by the udf_with_default_cols function, which falls back to the input column names declared in the model signature when the UDF is called without arguments. Hence, when we call our spark_udf for prediction, we are calling a pandas_udf.
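The wrapper in step 2 can be illustrated with a small, self-contained sketch. As far as I can tell from MLflow 2.x, calling the returned UDF with no arguments makes it use the input column names from the model signature, while explicit arguments are passed straight through (the else branch above). The with_default_cols helper below is a hypothetical re-creation of that idea, not MLflow’s code, and it uses plain callables and strings instead of Spark columns.

```python
import functools
from typing import Callable, Sequence


def with_default_cols(udf: Callable, input_names: Sequence[str]) -> Callable:
    """Hypothetical sketch of a default-columns wrapper around a UDF.

    With no arguments, call `udf` with the input names from the model
    signature; otherwise forward the given columns unchanged.
    """

    @functools.wraps(udf)  # keep the wrapped UDF's name and docstring
    def udf_with_default_cols(*args):
        if len(args) == 0:
            if not input_names:
                raise ValueError("no columns passed and the signature declares no input names")
            return udf(*input_names)
        return udf(*args)

    return udf_with_default_cols
```

For instance, wrapping a stand-in def fake_udf(*cols): return cols with input names ["sepal_length", "sepal_width"] makes a bare call return ("sepal_length", "sepal_width"), while a call with "value" returns ("value",).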

3. A localhost scoring server will be spun up in a child process, using an isolated conda environment, and will use the model loaded from the registry

def spark_udf(spark, model_uri, result_type="double", env_manager=_EnvManager.LOCAL):
    ...
    @pandas_udf(result_type)
    def udf(...) -> Iterator[result_type_hint]:
        ...
        if env_manager != _EnvManager.LOCAL:
            ...
            pyfunc_backend.prepare_env(
                model_uri=local_model_path_on_executor, capture_output=True
            )
            ...
            if check_port_connectivity():
                # launch scoring server
                server_port = find_free_port()
                host = "127.0.0.1"
                scoring_server_proc = pyfunc_backend.serve(
                    model_uri=local_model_path_on_executor or local_model_path,
                    port=server_port,
                    host=host,
                    timeout=MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT.get(),
                    enable_mlserver=False,
                    synchronous=False,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.STDOUT,
                )

                client = ScoringServerClient(host, server_port)
            ...

4. When called for inference, MLFlow’s Spark UDF will internally request the localhost scoring server and return its response

def batch_predict_fn(pdf):
    return client.invoke(pdf).get_predictions()

Is it too good to be true?

Caveats

As with everything in life, there are caveats we must be aware of. In this case, it’s worth pointing out two major considerations.

1. Serialization overhead

To call the internal localhost scoring server (the one running in a child process, listening on a local port and waiting for work), spark_udf uses a standard REST request-response approach.

When the UDF is called, the batch_predict_fn inner function is called:

def batch_predict_fn(pdf):
    return client.invoke(pdf).get_predictions()

Note that client.invoke is nothing more than a POST request pointing to 127.0.0.1:

def invoke(self, data):
    response = requests.post(
        url=self.url_prefix + "/invocations",
        data=dump_input_data(data),
        headers={"Content-Type": scoring_server.CONTENT_TYPE_JSON},
    )
    if response.status_code != 200:
        raise Exception(
            f"Invocation failed (error code {response.status_code}, response: {response.text})"
        )
    return PredictionsResponse.from_json(response.text)

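Therefore, dump_input_data serializes the partition's data, one pandas batch at a time (each batch Spark hands to the UDF, bounded by spark.sql.execution.arrow.maxRecordsPerBatch), into a JSON string, which the scoring server then has to parse back: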

def dump_input_data(data, inputs_key="inputs"):
    ...
    if not isinstance(post_data, str):
        post_data = json.dumps(post_data, cls=_CustomJsonEncoder)
    return post_data
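To get a rough feel for that overhead, the stdlib-only sketch below compares the size of one batch encoded as JSON against the same numbers packed as raw float64 values, and times a JSON encode/decode round trip (the client pays the encoding, the scoring server the decoding). It is an illustration, not a benchmark of MLflow: the {"inputs": ...} wrapper, the 10,000-row single-column batch and the helper names are assumptions, and MLflow’s real payload layout may differ.

```python
import json
import struct
import time


def json_payload(batch):
    # JSON request body; the "inputs" key is an assumed, simplified layout.
    return json.dumps({"inputs": batch}).encode("utf-8")


def binary_payload(batch):
    # Raw little-endian float64 values: 8 bytes per number, no text encoding.
    return struct.pack(f"<{len(batch)}d", *batch)


def round_trip_seconds(batch, repeats=5):
    # Average time to serialize and parse one batch back.
    start = time.perf_counter()
    for _ in range(repeats):
        json.loads(json_payload(batch))
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    batch = [i / 7 for i in range(10_000)]  # hypothetical single-column batch
    print("JSON bytes:  ", len(json_payload(batch)))
    print("binary bytes:", len(binary_payload(batch)))
    print(f"JSON round trip: {round_trip_seconds(batch) * 1000:.2f} ms")
```

Most of these floats need 16 or more characters as text, versus a fixed 8 bytes in binary, so the JSON body comes out noticeably larger, and that cost is paid on every batch of every partition.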

2. Built for map-like ML models

Want to use it for clustering? Gotta think twice. I tried. It can work but needs ugly workarounds that will limit scale. Why?

Because MLFlow’s Spark UDF is a pandas_udf of type “Iterator of Multiple Series to Iterator of Series”. As such, it maps one or more column values to a single column value, row by row, and the model only ever sees whichever batch of rows Spark hands it. In a clustering use case, the Spark DataFrame would contain one or more subsets of data to be clustered, and each subset must be fed to the model as a whole so that it can assign a cluster to every row. A row-wise UDF gives no such guarantee: the rows of one subset can be split across batches and partitions. A better fit for clustering use cases would probably be pyspark.sql.GroupedData.applyInPandas, which hands each group to the function as a single pandas DataFrame, instead of a standard pandas_udf. MLFlow doesn’t ship this solution yet.

Map-like model prediction vs group-by-like model calculation — image by the author
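The difference between the two execution models can be sketched in plain Python. In this hypothetical example, toy_cluster stands in for a clustering model (it labels each value by whether it lies above the mean of the values it received in the same call), map_style mimics a batch-at-a-time pandas_udf, and group_style mimics what applyInPandas would give us. None of these names come from MLflow or Spark.

```python
from collections import defaultdict
from itertools import islice
from statistics import mean


def toy_cluster(values):
    # Hypothetical stand-in for a clustering model: label 1 if a value is above
    # the mean of all values passed in the same call, else 0.
    m = mean(values)
    return [int(v > m) for v in values]


def map_style(rows, batch_size):
    """Like an iterator pandas_udf: the model only sees fixed-size batches of rows."""
    rows_iter = iter(rows)
    labels = []
    while batch := list(islice(rows_iter, batch_size)):
        labels.extend(toy_cluster([value for _key, value in batch]))
    return labels


def group_style(rows):
    """Like GroupedData.applyInPandas: the model sees each group in full."""
    groups = defaultdict(list)  # key -> list of (row index, value)
    for index, (key, value) in enumerate(rows):
        groups[key].append((index, value))
    labels = [None] * len(rows)
    for members in groups.values():
        group_labels = toy_cluster([value for _index, value in members])
        for (index, _value), label in zip(members, group_labels):
            labels[index] = label
    return labels
```

With rows = [("a", 1), ("a", 2), ("a", 9), ("a", 10)] and a batch size of 2, map_style returns [0, 1, 0, 1], because each half of group "a" is clustered on its own, while group_style returns [0, 0, 1, 1].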

To be fair, I’ll briefly point out two possible workarounds without diving into details (let me know in the comments if you want more details on how to work around this problem, even though the UDF is not optimized for it):

a) Use the collect_list function to group each data subset to be clustered into an array, so that all of its data ends up in one row.

In this way, we compress the whole subset into a single row and fall back to the row-by-row approach the UDF is built for.

Problem: the list (i.e., the whole subset to be clustered) can be very long, and it will have to go through JSON serialization and an internal POST request.
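Workaround (a) can be emulated in plain Python to show the shape of the data at each stage: a collect_list-like step packs each subset into one row, a row-wise “UDF” then clusters each packed list in a single call, and an explode-like step flattens the result back to one row per value. Everything here is a hypothetical stand-in for the Spark functions (groupBy with collect_list, the MLflow UDF, explode). In Spark, collect_list does not guarantee element order, which is why the sketch pairs each value with the label computed from the same packed row instead of relying on the original row order.

```python
from collections import defaultdict
from statistics import mean


def toy_cluster(values):
    # Hypothetical stand-in model: 1 if a value is above its list's mean, else 0.
    m = mean(values)
    return [int(v > m) for v in values]


def collect_list(rows):
    """Emulates groupBy(key).agg(collect_list(value)): one row per subset."""
    packed = defaultdict(list)
    for key, value in rows:
        packed[key].append(value)
    return list(packed.items())


def predict_packed(packed_rows):
    """Row-wise 'UDF': each row already holds a whole subset, so map-like is fine."""
    return [(key, values, toy_cluster(values)) for key, values in packed_rows]


def explode(predicted_rows):
    """Back to one row per value, pairing each value with its own label."""
    return [
        (key, value, label)
        for key, values, labels in predicted_rows
        for value, label in zip(values, labels)
    ]
```

Each packed row is exactly what would be serialized into a single JSON request, so the size of the largest subset directly bounds the size of one POST body.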

b) Split the Spark DataFrame into one DataFrame per subset to be clustered, and pass them to the model one by one, so that each call works on a whole subset.

Problem: you will need to make sure each individual DataFrame reaches the model in one piece, without being split across partitions or batches; otherwise, the model will end up seeing only partial data. Moreover, you risk running Spark actions sequentially, losing parallelism.

Photo by Brett Jordan on Unsplash

MLFlow Spark UDF is a great solution for data-heavy predictions running as Spark jobs. Coupled with conda environments, it isolates the model environment from the prediction service environment, saving us from the headache of syncing dependencies.

It achieves this by launching a local scoring server in a child process on the Spark executors, running inside an isolated conda environment.

Nevertheless, we should be aware of two possible drawbacks. First, the localhost scoring server is reached through local POST requests, which requires the data to be serialized to JSON and adds overhead. Second, the map-like nature of a pandas_udf, the basis of the MLFlow Spark UDF implementation, works great for row-based inference but needs workarounds when a computation requires groups of rows, as in clustering.

Happy (ML) pipeline! :)
