MLX/PyTorch speed analysis on MacBook Pro M3 Max

Istvan Benedek
12 min read · Mar 5, 2024


Two months ago, I got my new MacBook Pro M3 Max with 128 GB of memory, and only recently did I take the time to measure the speed difference in PyTorch matrix multiplication between the CPU (16 cores) and the GPU (40 cores). The results varied significantly with the system’s load, but one thing can be stated definitively: matrix multiplication runs astonishingly faster on the GPU than on the CPU.

With the introduction of Metal support in PyTorch (the mps device), using the GPU for machine learning on a MacBook Pro has become much more accessible, opening the way to the capabilities of Apple Silicon GPUs. This is a significant step towards faster AI and machine learning workflows on macOS, because tensor operations run through Metal, Apple’s low-level graphics and compute API.

Python code: https://github.com/CAG12/python/blob/main/pythorch.ipynb
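The notebook has the full code. The following is only a minimal, independent sketch of this kind of measurement, not a copy of the notebook. It assumes PyTorch 2.x, where torch.mps.synchronize() exists; the function names and the 2048 x 2048 matrix size are illustrative choices. The device is synchronized before the clock is read because MPS kernels are queued asynchronously and would otherwise look faster than they are.

```python
import time

import torch


def pick_device() -> torch.device:
    """Use the Metal (MPS) GPU when available, otherwise fall back to the CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def sync(device: torch.device) -> None:
    """Wait for queued GPU work to finish; MPS kernels run asynchronously."""
    if device.type == "mps":
        torch.mps.synchronize()


def time_matmul(n: int, device: torch.device, repeats: int = 10) -> float:
    """Return the mean wall-clock seconds of an n x n float32 matmul on `device`."""
    a = torch.rand(n, n, dtype=torch.float32, device=device)
    b = torch.rand(n, n, dtype=torch.float32, device=device)
    a @ b  # warm-up run (kernel setup, memory allocation)
    sync(device)
    start = time.perf_counter()
    for _ in range(repeats):
        a @ b
    sync(device)
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    n = 2048
    cpu_s = time_matmul(n, torch.device("cpu"))
    dev = pick_device()
    dev_s = time_matmul(n, dev)
    print(f"cpu: {cpu_s * 1e3:.1f} ms, {dev.type}: {dev_s * 1e3:.1f} ms, "
          f"ratio: {cpu_s / dev_s:.1f}x")
```

The absolute numbers depend on the machine and, as noted above, on the current system load, so it is worth running the measurement several times.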

Type limitation

However, there is a notable limitation: the MPS backend does not support float64, so float32 is the widest floating-point type available for matrix multiplication on the GPU (and, as far as I can tell, for other GPU operations too). In scientific computing and machine learning, precision matters. The float32 type is sufficient for many applications, but it carries only about 7 significant decimal digits compared with about 16 for float64 (double precision), which can affect the accuracy of results in high-precision tasks. This restriction calls for weighing the performance gained from GPU acceleration against the precision requirements of each application.
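To make the precision gap concrete, the sketch below uses only the Python standard library: struct round-trips a value through the 4-byte IEEE 754 single-precision format, which mimics storing it as float32. The helper name to_float32 is hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 and widen it back."""
    return struct.unpack("f", struct.pack("f", x))[0]


# 0.1 is not exactly representable; the float32 copy is correct to far fewer digits.
print(repr(0.1))              # float64 value
print(repr(to_float32(0.1)))  # float32 value, widened back to float64

# A tiny increment that float64 resolves but float32 rounds away.
print(1.0 + 1e-8 == 1.0)              # False in float64
print(to_float32(1.0 + 1e-8) == 1.0)  # True after rounding to float32
```

The second comparison shows why long accumulations (sums, dot products) in float32 can silently drop small contributions that float64 would keep.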

Metal support in PyTorch is a significant advancement, but in the PyTorch version I tested it had a second limitation: matrix multiplication on the GPU did not accept integer types such as int32 or int64. This poses additional challenges for developers and researchers, especially when working on tasks that require integer arithmetic or are inherently integer-based, such as indexing or certain types of discrete data processing.

Without int32 and int64 support, tasks that would normally benefit from or require these types need workarounds or alternative methods. For instance, integer-based operations used for performance or memory efficiency may have to be redesigned around float32, which can cost efficiency and precision: float32 represents integers exactly only up to 2^24 = 16,777,216 in magnitude.
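The 2^24 limit, and a conservative test of when an integer matrix multiplication can safely be moved to float32, can be sketched with the standard library. float32_matmul_is_exact is a hypothetical helper, and it assumes the float32 kernel rounds correctly: if every input, product and partial sum is an integer of magnitude at most 2^24, each value is exactly representable and no rounding occurs, whatever the summation order.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float to the nearest float32 (via struct) and widen it back."""
    return struct.unpack("f", struct.pack("f", x))[0]


# float32 has a 24-bit significand, so every integer up to 2**24 is exact...
assert all(to_float32(float(i)) == i for i in range(2**24 - 5, 2**24 + 1))
# ...but 2**24 + 1 is not: it rounds to 2**24.
print(to_float32(float(2**24 + 1)))  # 16777216.0


def float32_matmul_is_exact(max_abs_a: int, max_abs_b: int, inner_dim: int) -> bool:
    """Hypothetical helper: True if an integer matmul done in float32 cannot lose
    precision, i.e. all inputs, products and partial sums stay within +/- 2**24.
    The check is sufficient, not necessary (it assumes worst-case magnitudes)."""
    limit = 2**24
    inputs_ok = max(max_abs_a, max_abs_b) <= limit
    return inputs_ok and inner_dim * max_abs_a * max_abs_b <= limit
```

For example, dot products of length 1000 over values up to 100 in magnitude stay below 10^7, so casting such integer matrices to float32, multiplying on the GPU and rounding back would be exact; values up to 1000 would not pass the check.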

Moreover, this limitation impacts the applicability of GPU acceleration for a broader range of applications. Tasks such as graph algorithms, certain machine learning models that rely on discrete data, or operations that involve heavy indexing might not fully leverage the GPU’s capabilities, leading to missed opportunities for performance gains.

Addressing these challenges requires creative problem-solving from the community and possibly future updates to the framework to support a broader range of data types on GPUs. As the ecosystem evolves, ongoing developments may eventually bridge these gaps, offering more versatile and powerful tools for AI and machine learning on Apple Silicon-powered MacBooks.

Adapting to this limitation means adjusting workflows so that the use of float32 does not adversely affect the results. For developers and researchers, this may mean making algorithms more robust to limited precision or adding checks that validate results computed on the GPU, for example against a float64 reference computed on the CPU. Despite these challenges, Metal support in PyTorch is a promising advancement that pushes the boundaries of what’s possible with machine learning on MacBooks.
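One way to add such a check is to compare the float32 GPU result with a float64 reference computed on the CPU. Below is a minimal sketch, assuming PyTorch; max_relative_error is a hypothetical helper that measures the largest absolute deviation relative to the largest reference entry, and the 1e-4 tolerance is only an example.

```python
import torch


def max_relative_error(a: torch.Tensor, b: torch.Tensor, device: torch.device) -> float:
    """Compare a float32 matmul on `device` with a float64 reference on the CPU."""
    # Reference in double precision on the CPU (float64 is not available on MPS).
    ref = a.double() @ b.double()
    # Candidate: float32 on the chosen device, copied back for comparison.
    out = (a.float().to(device) @ b.float().to(device)).cpu().double()
    # Largest absolute deviation, scaled by the largest reference magnitude.
    return ((out - ref).abs().max() / ref.abs().max()).item()


if __name__ == "__main__":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    torch.manual_seed(0)
    a = torch.randn(512, 512)
    b = torch.randn(512, 512)
    err = max_relative_error(a, b, device)
    print(f"max relative error on {device.type}: {err:.2e}")
    # Example tolerance; choose one that matches the application's needs.
    assert err < 1e-4, "float32 result deviates more than expected"
```

Running the reference only on a sample of inputs keeps the cost of such a check small compared with the GPU computation itself.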

Update

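I’ve just found a very comprehensive benchmark on Medium (https://medium.com/towards-data-science/how-fast-is-mlx-a-comprehensive-benchmark-on-8-apple-silicon-chips-and-4-cuda-gpus-378a0ae356a0), so I ran the same tests on my M3 Max. Here are the results. The first five columns are runtimes (lower is better): mlx_gpu, mlx_gpu_compile and mlx_cpu are MLX on the GPU (without and with compilation) and on the CPU, while mps and cpu are PyTorch on the Metal (MPS) device and on the CPU; nan means there is no mps result for that operation. The speedup columns compare two of these runtimes: a positive value means the first backend named is faster, a negative value means it is slower.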

| Operation                                           | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup |
|-----------------------------------------------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------|
| Argmax / dim=64x1024x128 axi=0 | 1.68 | 1.54 | 9.03 | 0.74 | 20.95 | +9% | -55% | +438% |
| Argmax / dim=64x1024x128 axi=1 | 1.63 | 1.54 | 9.01 | 0.82 | 1.87 | +5% | -50% | +451% |
| Argmax / dim=64x1024x128 axi=2 | 1.63 | 1.62 | 8.92 | 0.72 | 1.41 | +0% | -55% | +447% |
| Argmax / dim=64x128x1024 axi=2 | 1.64 | 1.70 | 9.00 | 0.54 | 1.23 | -3% | -67% | +449% |
| BCE / dim=1000000 dim=1000000 | 0.24 | 0.49 | 6.16 | 0.28 | 0.74 | -51% | +15% | +2474% |
| BCE / dim=100000x32 dim=100000x32 | 0.50 | 0.26 | 19.94 | 0.35 | 1.85 | +90% | -30% | +3899% |
| BCE / dim=100000x64x2 dim=100000x64x2 | 1.84 | 0.77 | 80.59 | 0.79 | 6.44 | +139% | -57% | +4268% |
| BCE / dim=128x100000 dim=128x100000 | 1.83 | 0.83 | 81.05 | 0.53 | 6.60 | +120% | -71% | +4332% |
| Concat / dim=1000000x64 dim=1000000x32 axi=1 | 2.44 | 2.52 | 68.44 | 2.44 | 16.72 | -3% | 0% | +2706% |
| Concat / dim=1000000x64 dim=1000000x128 axi=1 | 4.60 | 4.64 | 193.54 | 4.69 | 36.77 | 0% | +1% | +4107% |
| Concat / dim=1000000x64 dim=1000000x64 axi=0 | 3.14 | 3.16 | 46.02 | 3.17 | 21.33 | 0% | +1% | +1367% |
| Concat / dim=64x1000000 dim=64x1000000 axi=0 | 3.16 | 3.13 | 68.77 | 3.17 | 20.68 | +0% | +0% | +2079% |
| Conv1d / dim=100x256x3 dim=8x3x3 | 0.39 | 0.25 | 0.34 | 0.28 | 2.46 | +54% | -26% | -13% |
| Conv1d / dim=100x256x256 dim=8x3x256 | 1.12 | 1.05 | 6.08 | 0.64 | 70.01 | +6% | -42% | +442% |
| Conv1d / dim=16x1000x80 dim=128x11x80 | 1.09 | 1.00 | 2.93 | 0.89 | 508.60 | +9% | -18% | +167% |
| Conv1d / dim=16x1000x3 dim=128x11x3 | 0.51 | 0.30 | 0.43 | 0.58 | 50.19 | +69% | +13% | -15% |
| Conv2d / dim=100x256x256x3 dim=8x3x3x3 | 2.57 | 2.57 | 802.52 | 2.16 | 117.42 | +0% | -16% | +31155% |
| Conv2d / dim=10x256x256x12 dim=8x3x3x12 | 4.23 | 4.27 | 325.15 | 1.03 | 13.20 | 0% | -75% | +7583% |
| Conv2d / dim=1x256x256x128 dim=8x3x3x128 | 0.43 | 0.45 | 501.10 | 0.97 | 26.96 | -3% | +122% | +115163% |
| Conv2d / dim=100x28x28x3 dim=8x3x3x3 | 0.21 | 0.21 | 9.32 | 0.34 | 1.49 | 0% | +61% | +4371% |
| Conv2d / dim=1000x28x28x3 dim=8x3x3x3 | 0.55 | 0.50 | 88.47 | 0.63 | 7.47 | +10% | +14% | +15906% |
| Gather / dim=64x256 dim=10 | 0.13 | 0.17 | 0.01 | 0.20 | 0.00 | -19% | +47% | -91% |
| Gather / dim=64x256 dim=1000 | 0.17 | 0.18 | 0.03 | 0.29 | 0.12 | -5% | +72% | -83% |
| Gather / dim=64x256 dim=1000000 | 5.00 | 4.95 | 16.05 | 41.62 | 38.33 | +0% | +732% | +221% |
| Gather / dim=1024x32 dim=10 | 0.17 | 0.13 | 0.01 | 0.20 | 0.00 | +29% | +13% | -95% |
| Gather / dim=1024x32 dim=1000 | 0.17 | 0.14 | 0.01 | 0.22 | 0.10 | +25% | +28% | -92% |
| Gather / dim=1024x32 dim=1000000 | 0.86 | 0.81 | 5.12 | 5.31 | 4.76 | +5% | +518% | +496% |
| LeakyReLU / dim=128x16x1024 | 0.25 | 0.20 | 0.40 | 0.26 | 0.39 | +23% | +5% | +60% |
| LeakyReLU / dim=64x128x1024 | 0.33 | 0.32 | 1.02 | 0.47 | 1.58 | +3% | +40% | +207% |
| Linear / dim=100x1024x32 dim=32x1024 dim=1024 | 4.06 | 4.04 | 16.93 | 1.56 | 49.45 | +0% | -61% | +316% |
| Linear / dim=100x1024x64 dim=64x1024 dim=1024 | 4.24 | 4.19 | 20.02 | 2.06 | 76.48 | +1% | -51% | +372% |
| Linear / dim=100x1024x256 dim=256x1024 dim=1024 | 7.47 | 7.41 | 31.45 | 5.70 | 107.12 | +0% | -23% | +321% |
| Linear / dim=100x1024x512 dim=512x1024 dim=1024 | 11.99 | 11.99 | 47.44 | 10.47 | 150.67 | +0% | -12% | +295% |
| Linear / dim=100x1x51200 dim=51200x1 dim=1 | 0.54 | 0.47 | 0.20 | 0.58 | 10.20 | +14% | +7% | -63% |
| MatMul / dim=32x1x1000 dim=32x1000x128 | 0.17 | 0.19 | 0.07 | 0.33 | 0.65 | -12% | +95% | -57% |
| MatMul / dim=1000x64x256 dim=256x32 | 0.58 | 0.41 | 1.30 | 0.91 | 14.86 | +40% | +57% | +124% |
| MatMul / dim=1000x64x1024 dim=1000x1024x32 | 1.34 | 1.34 | 9.74 | 1.57 | 354.48 | +0% | +17% | +628% |
| MatMul / dim=1000x1024x64 dim=1000x64x256 | 4.62 | 4.63 | 39.16 | 5.17 | 887.66 | 0% | +11% | +746% |
| MatMul / dim=64x1000000 dim=1000000x32 | 2.78 | 2.76 | 7.37 | 3.79 | 68.34 | +0% | +36% | +165% |
| MatMul / dim=1000000x64 dim=64x1024 | 15.98 | 15.77 | 83.70 | 32.40 | 1740.92 | +1% | +102% | +423% |
| PReLU / dim=128x16x1024 dim=1 | 0.29 | 0.25 | 1.53 | 0.26 | 0.45 | +18% | -11% | +421% |
| PReLU / dim=64x128x1024 dim=1 | 0.38 | 0.34 | 4.49 | 0.38 | 1.54 | +11% | 0% | +1068% |
| ReLU / dim=128x16x1024 | 0.44 | 0.22 | 0.31 | 0.30 | 0.40 | +102% | -33% | -29% |
| ReLU / dim=64x128x1024 | 0.34 | 0.33 | 0.62 | 0.51 | 1.48 | +3% | +47% | +81% |
| Scatter / dim=64x16 dim=10 | 0.13 | 0.17 | 0.01 | 0.15 | 0.00 | -25% | +14% | -91% |
| Scatter / dim=64x16 dim=1000 | 0.14 | 0.18 | 0.07 | 0.15 | 0.06 | -21% | +10% | -47% |
| Scatter / dim=64x16 dim=1000000 | 0.35 | 0.33 | 52.84 | 2.61 | 2.29 | +5% | +653% | +15191% |
| Scatter / dim=1024x32 dim=10 | 0.16 | 0.17 | 0.01 | 0.15 | 0.00 | -4% | -6% | -91% |
| Scatter / dim=1024x32 dim=1000 | 0.16 | 0.17 | 0.12 | 0.13 | 0.06 | -3% | -16% | -27% |
| Scatter / dim=1024x32 dim=1000000 | 0.51 | 0.50 | 99.37 | 5.20 | 3.02 | +2% | +923% | +19456% |
| ScatterSum / dim=64x16 dim=10 | 0.04 | 0.03 | 0.01 | nan | 0.00 | +28% | nan% | -74% |
| ScatterSum / dim=64x16 dim=1000 | 0.03 | 0.03 | 0.01 | nan | 0.00 | +11% | nan% | -71% |
| ScatterSum / dim=64x16 dim=1000000 | 0.03 | 0.03 | 0.01 | nan | 1.19 | +8% | nan% | -70% |
| ScatterSum / dim=1024x32 dim=10 | 0.03 | 0.03 | 0.01 | nan | 0.01 | +5% | nan% | -70% |
| ScatterSum / dim=1024x32 dim=1000 | 0.03 | 0.03 | 0.01 | nan | 0.01 | +8% | nan% | -71% |
| ScatterSum / dim=1024x32 dim=1000000 | 0.03 | 0.03 | 0.01 | nan | 6.18 | +0% | nan% | -70% |
| ScatterMax / dim=64x16 dim=10 | 0.03 | 0.03 | 0.01 | nan | 0.00 | +1% | nan% | -71% |
| ScatterMax / dim=64x16 dim=1000 | 0.03 | 0.03 | 0.01 | nan | 0.00 | +6% | nan% | -71% |
| ScatterMax / dim=64x16 dim=1000000 | 0.03 | 0.03 | 0.01 | nan | 1.17 | +3% | nan% | -71% |
| ScatterMax / dim=1024x32 dim=10 | 0.03 | 0.03 | 0.01 | nan | 0.01 | +3% | nan% | -71% |
| ScatterMax / dim=1024x32 dim=1000 | 0.03 | 0.03 | 0.01 | nan | 0.01 | +8% | nan% | -67% |
| ScatterMax / dim=1024x32 dim=1000000 | 0.03 | 0.03 | 0.01 | nan | 6.37 | +1% | nan% | -62% |
| SeLU / dim=128x16x1024 | 0.45 | 0.30 | 2.17 | 0.29 | 1.37 | +50% | -35% | +378% |
| SeLU / dim=64x128x1024 | 0.36 | 0.33 | 6.38 | 0.50 | 5.42 | +7% | +39% | +1675% |
| Sigmoid / dim=128x16x1024 | 0.18 | 0.21 | 1.50 | 0.40 | 1.32 | -15% | +119% | +723% |
| Sigmoid / dim=64x128x1024 | 0.36 | 0.33 | 6.22 | 0.64 | 4.95 | +9% | +75% | +1609% |
| Softmax / dim=64x1000000 axi=-1 | 5.81 | 4.41 | 41.60 | 3.11 | 22.36 | +31% | -46% | +615% |
| Softmax / dim=1000000x64 axi=-1 | 5.97 | 4.38 | 43.13 | 3.94 | 21.45 | +36% | -34% | +622% |
| Softmax / dim=64x16x32x1024 axi=-1 | 3.52 | 2.41 | 21.90 | 1.78 | 9.80 | +46% | -49% | +522% |
| Softmax / dim=128x16x32x1024 axi=-1 | 6.33 | 4.58 | 43.89 | 3.44 | 22.96 | +38% | -45% | +593% |
| Softmax / dim=1024x16x32x128 axi=-1 | 6.05 | 4.59 | 42.80 | 4.07 | 25.29 | +31% | -32% | +607% |
| Softmax / dim=1024x64x32x8 axi=-1 | 1.66 | 1.31 | 10.63 | 1.35 | 12.85 | +26% | -18% | +538% |
| Softplus / dim=128x16x1024 | 0.23 | 0.21 | 11.44 | 0.32 | 1.74 | +8% | +39% | +4890% |
| Softplus / dim=64x128x1024 | 0.32 | 0.34 | 44.75 | 0.50 | 6.96 | -4% | +56% | +13890% |
| Sort / dim=64x128x1024 axi=0 | 0.74 | 0.73 | 228.52 | 8.83 | 62.04 | +2% | +1086% | +30626% |
| Sort / dim=64x128x1024 axi=1 | 0.73 | 0.74 | 227.92 | 8.48 | 33.13 | -1% | +1054% | +30913% |
| Sort / dim=64x128x1024 axi=2 | 0.74 | 0.75 | 227.72 | 6.25 | 32.30 | -1% | +745% | +30706% |
| Sum / dim=64x128x128x128 axi=0 | 1.55 | 1.54 | 6.50 | 1.65 | 8.85 | +0% | +6% | +320% |
| Sum / dim=64x128x128x128 axi=1 | 1.54 | 1.55 | 6.51 | 1.70 | 7.59 | 0% | +10% | +322% |
| Sum / dim=64x128x128x128 axi=2 | 1.55 | 1.56 | 6.50 | 1.65 | 5.62 | 0% | +6% | +320% |
| Sum / dim=64x128x128x128 axi=3 | 1.56 | 1.56 | 6.51 | 2.55 | 4.70 | +0% | +63% | +318% |
| SumAll / dim=64x128x128x128 | 1.56 | 1.54 | 6.53 | 1.73 | 4.35 | +0% | +10% | +318% |
| SumAll / dim=1000000 | 0.18 | 0.18 | 0.05 | 0.21 | 0.08 | -1% | +12% | -71% |
| SumAll / dim=1000000x128 | 1.48 | 1.49 | 6.48 | 1.60 | 4.22 | 0% | +8% | +338% |
| SumAll / dim=128x1000000 | 1.53 | 1.47 | 6.21 | 1.71 | 4.14 | +4% | +11% | +304% |
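The speedup columns are consistent with the formula speedup = baseline runtime / candidate runtime - 1 (for example, 9.03 / 1.68 - 1 ≈ 4.38, i.e. the +438% in the first row). A small sketch of that calculation follows; because the table shows rounded runtimes, recomputing from them reproduces the printed percentages only approximately (the same row gives about -56% instead of -55% for mlx_gpu vs mps).

```python
def speedup(baseline: float, candidate: float) -> float:
    """Relative speedup of `candidate` over `baseline`, computed from two runtimes.

    Positive: the candidate is faster; negative: the candidate is slower.
    """
    return baseline / candidate - 1.0


# Argmax / dim=64x1024x128 axi=0: mlx_gpu=1.68, mlx_cpu=9.03, mps=0.74
print(f"mlx_gpu vs mlx_cpu: {speedup(9.03, 1.68):+.1%}")
print(f"mlx_gpu vs mps:     {speedup(0.74, 1.68):+.1%}")
```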

Averaged over all input shapes of each operation (the speedup columns are computed from the averaged runtimes):
| Operation | mlx_gpu | mlx_gpu_compile | mlx_cpu | mps | cpu | mlx_gpu_compile/mlx_gpu speedup | mlx_gpu/mps speedup | mlx_gpu/mlx_cpu speedup |
|-----------------|-------|---------------|-------|------|------|-------------------------------|-------------------|-----------------------|
| Argmax | 1.64 | 1.60 | 8.99 | 0.70 | 6.36 | +2% | -57% | +446% |
| BCE | 1.10 | 0.59 | 46.93 | 0.49 | 3.91 | +87% | -55% | +4155% |
| Concat | 3.33 | 3.36 | 94.19 | 3.37 | 23.88 | 0% | +1% | +2726% |
| Conv1d | 0.78 | 0.65 | 2.44 | 0.60 | 157.82 | +19% | -23% | +214% |
| Conv2d | 1.60 | 1.60 | 345.31 | 1.02 | 33.31 | 0% | -36% | +21493% |
| Gather | 1.08 | 1.06 | 3.54 | 7.97 | 7.22 | +1% | +635% | +226% |
| LeakyReLU | 0.29 | 0.26 | 0.71 | 0.37 | 0.99 | +11% | +25% | +144% |
| Linear | 5.66 | 5.62 | 23.21 | 4.08 | 78.78 | +0% | -27% | +310% |
| MatMul | 4.24 | 4.18 | 23.56 | 7.36 | 511.15 | +1% | +73% | +455% |
| PReLU | 0.34 | 0.30 | 3.01 | 0.32 | 1.00 | +14% | -5% | +788% |
| ReLU | 0.39 | 0.28 | 0.47 | 0.40 | 0.94 | +42% | +2% | +18% |
| Scatter | 0.24 | 0.25 | 25.40 | 1.40 | 0.91 | -4% | +480% | +10446% |
| ScatterSum | 0.03 | 0.03 | 0.01 | nan | 1.23 | +10% | nan% | -71% |
| ScatterMax | 0.03 | 0.03 | 0.01 | nan | 1.26 | +4% | nan% | -69% |
| SeLU | 0.41 | 0.32 | 4.27 | 0.40 | 3.39 | +28% | -2% | +952% |
| Sigmoid | 0.27 | 0.27 | 3.86 | 0.52 | 3.13 | 0% | +90% | +1313% |
| Softmax | 4.89 | 3.61 | 33.99 | 2.95 | 19.12 | +35% | -39% | +595% |
| Softplus | 0.27 | 0.27 | 28.09 | 0.41 | 4.35 | +0% | +49% | +10133% |
| Sort | 0.74 | 0.74 | 228.05 | 7.85 | 42.49 | 0% | +962% | +30748% |
| Sum | 1.55 | 1.55 | 6.51 | 1.89 | 6.69 | 0% | +22% | +320% |
| SumAll | 1.19 | 1.17 | 4.82 | 1.31 | 3.20 | +1% | +10% | +305% |
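The averaged table is consistent with taking the mean runtime of each operation over all its input shapes and then recomputing the speedups from those means (the Argmax and Conv1d rows, for instance, match this). A small sketch of that aggregation follows, using a few rows copied from the detailed table; ROWS and average_by_operation are illustrative names, and since the inputs are rounded, the last digit can differ from the table above.

```python
from collections import defaultdict
from statistics import mean

# A few rows copied from the detailed table: (operation / shape, mlx_gpu, mlx_cpu, mps)
ROWS = [
    ("Argmax / dim=64x1024x128 axi=0", 1.68, 9.03, 0.74),
    ("Argmax / dim=64x1024x128 axi=1", 1.63, 9.01, 0.82),
    ("Argmax / dim=64x1024x128 axi=2", 1.63, 8.92, 0.72),
    ("Argmax / dim=64x128x1024 axi=2", 1.64, 9.00, 0.54),
    ("Sort / dim=64x128x1024 axi=0", 0.74, 228.52, 8.83),
    ("Sort / dim=64x128x1024 axi=1", 0.73, 227.92, 8.48),
    ("Sort / dim=64x128x1024 axi=2", 0.74, 227.72, 6.25),
]


def average_by_operation(rows):
    """Group rows by the operation name before ' / ' and average each runtime column."""
    groups = defaultdict(list)
    for label, *times in rows:
        groups[label.split(" / ")[0]].append(times)
    return {op: [mean(col) for col in zip(*ts)] for op, ts in groups.items()}


for op, (gpu, cpu, mps) in average_by_operation(ROWS).items():
    # Speedups are ratios of the averaged runtimes, minus one.
    print(f"{op}: mlx_gpu={gpu:.3f}  vs mlx_cpu {cpu / gpu - 1:+.0%}  vs mps {mps / gpu - 1:+.0%}")
```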

In diagram format; the visualization code is on GitHub:

https://github.com/CAG12/python/tree/main/mlx_benchmark_visualization


Istvan Benedek

46-year-old computer/data scientist guy who just left his excellent job at a large insurance company to chase his dreams and work on mathematics and AI