MPS or MLX for Domestic AI? The Answer Will Surprise You

Mike Koypish
7 min read · Dec 23, 2023


Image generated by DALLE

Not much time has passed since the introduction of a viable option for “local” deep learning: MPS. This API, essentially a GPU driver, enables efficient neural-network operations on Mac computers equipped with M-series GPUs. Leading frameworks like PyTorch and TensorFlow rapidly integrated MPS, and it has gained traction primarily for convenient laptop-based prototyping rather than for substantial model training.

Might this seem like an unexpected move for a company known primarily for consumer-grade electronics? Apple, however, has been making strides. Amid the AI fervor of 2023, it unveiled a new and somewhat surprising development: a deep-learning framework named MLX. As of now, MLX is minimalistic, offering only the fundamental components necessary for constructing deep architectures, and it is intended to work only on Apple's chips.

At its core, MLX leverages MPS for tensor operations. Technically, comparing “MLX versus MPS” isn't entirely accurate; a more appropriate comparison would be between MLX and “PyTorch-on-MPS”. Nonetheless, the former phrasing has become more popular online, so I'll adopt it for this discussion.

Now, let’s put it to the test on a laptop equipped with an M1 Max chip!

Phi-2 LLM

We could simply train a standard MNIST classifier and benchmark performance on it, but why not do something cooler and run a small LLM? Sorry for the unintended wordplay, but we simply can't handle a large LLM on a memory-constrained laptop!

Just a week ago, Microsoft released a new tiny LLM called Phi-2. It has only 2.7 billion parameters, so even 16 GB of RAM is enough. A (pre-)trained version is available on HF. Moreover, an implementation of Phi-2 was published as part of the MLX examples, which simplifies our work: all we need to do is download the weights from HF and load them into the MLX-based model. We will run only inference (the M1 Pro is still too slow for training) and compare performance against the original PyTorch-on-MPS implementation.

To clarify, we will not be using quantization here. While that technique would let an even bigger model fit in memory, for now it's more interesting to run the model with regular half-precision arithmetic.

PyTorch on MPS

Let's start with the ready-to-use implementation from HF. All we need is a handful of lines (you can find an extended version with performance measurements here):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Check that MPS is available for PyTorch
assert torch.backends.mps.is_available()

torch.set_default_device("mps")

# Download model (with weights) and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/phi-2", trust_remote_code=True
)


# Default precision should be 16 bit
assert model.config.torch_dtype == torch.float16

# Tokenize text input
inputs = tokenizer('''What are the prime factors of 10?''',
                   return_tensors="pt", return_attention_mask=False)

assert inputs['input_ids'].device.type == 'mps', "Inputs should be on MPS"

# Inference happens here (and takes time/resources)
outputs = model.generate(**inputs, max_length=200)

# Decoding result
text = tokenizer.batch_decode(outputs)[0]
print(text)

Inference of 200 tokens takes 17.3 s ± 160 ms on my laptop. The loaded model occupies 6.5 GB of memory (consistent with 2.7 billion half-precision parameters) and consumes an additional 2.4 GB of regular memory (outside the MPS driver) during inference. But more on the benchmarking later; let's now switch to MLX.
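
As a side note, here is a minimal sketch of how such a wall-clock measurement can be taken (the linked notebook uses repeated runs; the timed_generate helper below is purely illustrative). The key point is to synchronize with the GPU before starting and stopping the timer:

import time
import torch

def timed_generate(model, inputs, max_length=200):
    # Wait for any queued GPU work so the timer covers only generation
    torch.mps.synchronize()
    start = time.perf_counter()
    outputs = model.generate(**inputs, max_length=max_length)
    # Make sure generation has actually finished on the GPU before reading the clock
    torch.mps.synchronize()
    return outputs, time.perf_counter() - start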

Phi-2 on MLX

Let's quickly review the example implementation of Phi-2. Since MLX provides the basic building blocks of a neural network, such as linear layers, layer normalization, popular activation functions (GELU in this case), an embedding layer, and even multi-head attention, it takes just a hundred lines of sparse code to reconstruct the Phi-2 architecture exactly. In general, the code is similar to what you'd write with high-level PyTorch/Keras.
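
To give a flavor of the API, below is an illustrative transformer block assembled from MLX's stock layers. It is not Phi-2's actual block, and the dimensions are made up for the example; it just shows how close the code feels to high-level PyTorch:

import mlx.core as mx
import mlx.nn as nn

class Block(nn.Module):
    # Illustrative block only: layer norm, multi-head attention, and a GELU MLP
    def __init__(self, dims: int = 256, num_heads: int = 4, mlp_dims: int = 1024):
        super().__init__()
        self.ln = nn.LayerNorm(dims)
        self.attn = nn.MultiHeadAttention(dims, num_heads)
        self.fc1 = nn.Linear(dims, mlp_dims)
        self.fc2 = nn.Linear(mlp_dims, dims)
        self.act = nn.GELU()

    def __call__(self, x, mask=None):
        h = self.ln(x)
        # Attention and the MLP both read the normalized input; their outputs
        # are added back to the residual stream
        return x + self.attn(h, h, h, mask) + self.fc2(self.act(self.fc1(h)))

x = mx.random.normal((1, 16, 256))  # (batch, sequence length, model dims)
print(Block()(x).shape)             # (1, 16, 256)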

When it comes to the low-level side of MLX, its array interface is very similar to numpy's. And it has a nice feature: you won't run into any device hassles in the code. Since on Apple silicon both the CPU and the GPU have access to the same RAM (unified memory architecture), you don't need to think about placing tensors on the right device. Indeed, you can run an operation on any device (either CPU or GPU), no matter where exactly the tensors reside in memory. Borrowing an example from the docs, both of these operations execute successfully:

import mlx.core as mx
a, b = mx.random.uniform(shape=(4096, 512)), mx.random.uniform(shape=(4096, 512))
mx.add(a, b, stream=mx.cpu)
mx.add(a, b, stream=mx.gpu)

Why mention this nuance here? Because, in theory, it could be used for optimization: the CPU could execute some operations in parallel with the GPU. It's unlikely to give a big boost, but let's keep it in mind for now and come back to it when analyzing the benchmark results.
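
For instance (purely illustrative, not something the Phi-2 example does), two independent operations could be queued on different streams and evaluated together:

import mlx.core as mx

a = mx.random.uniform(shape=(2048, 2048))
b = mx.random.uniform(shape=(2048, 2048))
c = a @ b                           # default stream (the GPU on Apple silicon)
d = mx.matmul(a, b, stream=mx.cpu)  # the same product, queued on the CPU
mx.eval(c, d)                       # the lazy graphs are only computed here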

Now let's look at performance. My code with all the measurements is located here. In brief: inference takes only 6.56 s ± 14.6 ms, which is around 3 times faster!
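
One caveat when timing MLX: evaluation is lazy, so the output has to be forced with mx.eval() before the clock is stopped, otherwise you measure only graph construction. A rough sketch (generate_fn is a hypothetical stand-in for the generation loop in the mlx-examples script):

import time
import mlx.core as mx

def timed_mlx_generate(generate_fn, prompt, max_tokens=200):
    start = time.perf_counter()
    tokens = generate_fn(prompt, max_tokens=max_tokens)
    mx.eval(tokens)  # force the lazy computation before reading the timer
    return tokens, time.perf_counter() - start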

Performance benchmarking

Let’s quickly review how the measurement was conducted.

For RAM consumption, we can use several gauges:

import gc
import torch
import psutil


print(
    "Tensors on MPS (excluding some cached allocations):",
    torch.mps.current_allocated_memory()
)
print("Total on MPS driver:", torch.mps.driver_allocated_memory())

print("Allocated for current process:",
      psutil.Process().memory_info().rss
)
print("Total system memory used:",
      psutil.virtual_memory().total - psutil.virtual_memory().available
)

# Invoke this before/after the operation of interest,
# so that cached/garbage allocations don't distort our memory measurements
def empty_cache():
    gc.collect()
    torch.mps.empty_cache()
    print('Cache emptied: python (GC) and MPS')

We also need to explicitly invoke Python's garbage collector and clear the MPS cache, so that stale allocations don't distort our results. On top of that, we employ memory-profiler with its %%memit IPython extension to analyze peak memory consumption.
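
For completeness, this is roughly how the %%memit magic is used in a notebook (memory-profiler has to be installed first; the cell body is whatever operation you want to profile):

# In an IPython/Jupyter session
%load_ext memory_profiler

%%memit
outputs = model.generate(**inputs, max_length=200)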

Depending on how you measure, both approaches consume fairly similar amounts of memory (around 6 GB for the model itself, plus an additional 2–3 GB during inference). But memory is tricky to measure correctly in such a complicated setup (extensive use of multiprocessing, plus the MPS driver). Besides, it isn't what matters most here; we're mainly interested in speed.

Let's take a look at the speed results, which are fairly impressive (slow CPU-based PyTorch was added as a third competitor):

Seems like a huge win for MLX. Around 3 times faster!

But the question arises: what could be the reason for such a significant improvement? Some crazily efficient usage of Apple's GPU? Why can't PyTorch do the same with the MPS device? Maybe MLX carefully leverages the specific GPU architecture, while PyTorch's compilers are optimized for NVIDIA? Maybe something with the unified memory architecture?

To find the answer, let's look at what happens with CPU/GPU utilization. We will use powermetrics and analyze its output with some code borrowed from the amazing asitop package:

sudo nice -n 10 powermetrics --samplers cpu_power,gpu_power,thermal -o powermetrics.txt -f plist -i 1000
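
Here is a minimal sketch of how that plist stream can be turned into GPU frequency and active-time samples. Note that the delimiter and the key names ("gpu", "freq_hz", "idle_ratio") are assumptions borrowed from asitop's parser, so verify them against your version of powermetrics:

import plistlib

def gpu_samples(path="powermetrics.txt"):
    # powermetrics with `-f plist` emits one plist per sampling interval,
    # apparently separated by NUL bytes
    with open(path, "rb") as f:
        raw = f.read()
    for chunk in raw.split(b"\x00"):
        chunk = chunk.strip()
        if not chunk:
            continue
        sample = plistlib.loads(chunk)
        gpu = sample.get("gpu", {})
        yield {
            "freq_MHz": gpu.get("freq_hz"),                        # sampled GPU frequency
            "active_pct": 100 * (1 - gpu.get("idle_ratio", 1.0)),  # GPU active-time share
        }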

Below is the PyTorch-on-MPS run. As you can see, GPU utilization looks almost perfect: above 90% throughout inference.

And this is the MLX run, which is 3 times faster. The GPU active-time percentage is even slightly lower than PyTorch's, at 89%.

Do you spot the difference? Right, the GPU frequency! PyTorch's frequency stays below 400 MHz, while MLX runs at 1200 MHz. Moreover, that is exactly a factor of 3, the same factor we saw when comparing execution times. This confirms our guess and lets us conclude that the source of MLX's higher performance lies in the hardware, not in some clever software optimizations.

So we found what speeds up MLX, and it is quite surprising. Why can't PyTorch drive the GPU the same way? It's well known that Apple's GPU scales its frequency, dropping into a low-power mode when idle. Perhaps it doesn't automatically ramp up under heavy load and instead has to be explicitly asked for a boosted mode. Maybe PyTorch is unaware of boosted modes, or the current MPS API simply doesn't expose them? I'm not an expert here and can't answer the question. I was unable to find any relevant information online, except that the same low-frequency issue has been reported with TensorFlow. In any case, we can say with some confidence that PyTorch's defeat is rather incidental, and the slowness looks fixable.

Summary

Finally, let’s recap the results of our journey:

  • We experimented with MLX, a new deep-learning framework designed for Mac computers, by running Phi-2, a state-of-the-art tiny LLM.
  • In comparing the inference speed of the MLX-based model with that of a PyTorch model on the same Apple M1-Pro GPU, we discovered that MLX is three times faster.
  • Such a big speed-up is pretty suspicious. Is it due to some smart optimizations specific to this GPU architecture? Or does clever use of unified memory bring some boost?
  • After a detailed analysis of GPU usage, we found an answer, and it may surprise you: for reasons yet to be understood, the GPU operates at a lower (probably the minimum possible) frequency when running PyTorch. So the reason is not some smart MLX optimization; PyTorch simply doesn't run the hardware “engine” at full power. If PyTorch drove the GPU at the same maximum frequency that MLX does, performance (tokens/second) would likely be about the same.
  • The full benchmark code is located here.
