TensorFlow Model Optimization Toolkit — Pruning API
Since we introduced the Model Optimization Toolkit — a suite of techniques that developers, both novice and advanced, can use to optimize machine learning models — we have been busy working on our roadmap to add several new approaches and tools. Today, we are happy to share the new weight pruning API.
Optimizing machine learning programs can take very different forms. Fortunately, neural networks have proven resilient to different transformations aimed at this goal.
One such family of optimizations reduces the number of parameters and operations involved in the computation by removing connections, and thus parameters, between neural network layers.
The weight pruning API is built on top of Keras, so developers can easily apply this technique to any existing Keras training program. This API will be part of a new GitHub repository for the Model Optimization Toolkit, along with many upcoming optimization techniques.
What is weight pruning?
Weight pruning means eliminating unnecessary values in the weight tensors. In practice, we set selected neural network parameters to zero, removing what we estimate are unnecessary connections between the layers of a neural network. This is done during training so that the network can adapt to the changes.
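One simple way to picture "removing a connection while training continues" is a binary mask that is re-applied after every weight update, so pruned entries stay at zero while the surviving weights keep learning. The plain-Python sketch below assumes that mask-based view; `apply_mask` and `sgd_step` are hypothetical helpers for illustration, not functions from the toolkit.

```python
# Minimal sketch: a binary mask keeps pruned connections at zero while
# the remaining weights continue to train. Plain lists stand in for
# weight tensors; apply_mask and sgd_step are hypothetical helpers,
# not part of the toolkit's API.


def apply_mask(weights, mask):
    """Return a copy of weights with every masked-out (0) entry set to zero."""
    return [w if m else 0.0 for w, m in zip(weights, mask)]


def sgd_step(weights, grads, mask, lr=0.1):
    """Plain gradient-descent update, then re-apply the mask so that
    pruned connections cannot drift away from zero."""
    updated = [w - lr * g for w, g in zip(weights, grads)]
    return apply_mask(updated, mask)


weights = [0.8, -0.05, 0.3, 0.01]
mask = [1, 0, 1, 0]  # 0 marks a connection we decided to remove
weights = apply_mask(weights, mask)
weights = sgd_step(weights, grads=[0.1, 0.5, -0.2, 0.3], mask=mask)
print(weights)
```

Even though the pruned positions receive non-zero gradients, they remain exactly zero after the update, while the other two weights move as usual.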
Why is weight pruning useful?
An immediate benefit of this work is disk compression: sparse tensors are amenable to compression. By applying simple file compression to the pruned TensorFlow checkpoint, or to the converted TensorFlow Lite model, we can reduce the size of the model for storage and transmission. For example, in the tutorial, we show how a 90% sparse model for MNIST can be compressed from 12MB to 2MB.
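To see why sparsity helps a generic file compressor, here is a small pure-Python sketch. It is not the tutorial's code: it builds a random float32 "tensor", zeroes the 90% of entries with the smallest magnitude, and compares compressed sizes. The tensor size and the use of `zlib` as the compressor are our own assumptions for illustration, so the exact ratio will differ from the tutorial's numbers.

```python
import random
import struct
import zlib


def to_bytes(values):
    """Serialize floats as little-endian float32, like a raw tensor dump."""
    return struct.pack(f"<{len(values)}f", *values)


random.seed(0)
n = 100_000
dense = [random.gauss(0.0, 1.0) for _ in range(n)]

# Magnitude pruning to 90% sparsity: keep only the largest 10% of |w|.
threshold = sorted(abs(w) for w in dense)[int(0.9 * n)]
sparse = [w if abs(w) >= threshold else 0.0 for w in dense]

dense_zip = len(zlib.compress(to_bytes(dense), 9))
sparse_zip = len(zlib.compress(to_bytes(sparse), 9))
print(f"raw: {4 * n} bytes, dense zipped: {dense_zip}, sparse zipped: {sparse_zip}")
```

Random dense weights barely compress, while the long runs of zeros in the pruned version shrink to a fraction of the original size.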
Moreover, across several experiments, we found that weight pruning is compatible with quantization, resulting in compound benefits. In the same tutorial, we show how we can further compress the pruned model from 2MB to just 0.5MB by applying post-training quantization.
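A rough sketch of why quantization compounds with pruning: storing each weight as an 8-bit integer instead of a 32-bit float cuts the raw size by 4x, and pruned zeros stay exactly zero after quantization, so the result still compresses well. The symmetric, per-tensor scheme below is a generic illustration we chose; it is not TensorFlow Lite's post-training quantization implementation, and the tensor is synthetic.

```python
import random
import struct
import zlib


def quantize_int8(values):
    """Symmetric linear quantization: map floats to integers in [-127, 127].

    Returns the integer codes and the scale needed to dequantize them.
    """
    max_abs = max(abs(v) for v in values) or 1.0  # avoid dividing by zero
    scale = max_abs / 127.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Approximate reconstruction of the original floats."""
    return [c * scale for c in codes]


# A toy, roughly 90%-sparse weight tensor.
random.seed(1)
n = 10_000
weights = [random.gauss(0.0, 1.0) if random.random() < 0.1 else 0.0 for _ in range(n)]

codes, scale = quantize_int8(weights)
float_bytes = struct.pack(f"<{n}f", *weights)  # 4 bytes per weight
int8_bytes = struct.pack(f"<{n}b", *codes)     # 1 byte per weight

float_zip = len(zlib.compress(float_bytes, 9))
int8_zip = len(zlib.compress(int8_bytes, 9))
print(f"float32: {len(float_bytes)} raw / {float_zip} zipped")
print(f"int8:    {len(int8_bytes)} raw / {int8_zip} zipped")
```

The 4x reduction in raw bytes per weight is in line with the 2MB to 0.5MB figure above, though real models also carry metadata and non-prunable tensors.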
In the future, TensorFlow Lite will add first-class support for sparse representation and computation, thus expanding the compression benefit to the runtime memory and unlocking performance improvements, since sparse tensors allow us to skip otherwise unnecessary computations involving the zeroed values.
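The runtime benefit comes from never touching the zeros. As an illustration, the sketch below stores a pruned matrix in compressed sparse row (CSR) form, a common sparse layout, and multiplies it by a vector while counting multiplications. CSR is our choice for the example; we are not claiming it is the representation TensorFlow Lite will adopt.

```python
def to_csr(matrix):
    """Convert a dense row-major matrix (list of lists) to CSR form,
    storing only the non-zero values and their column indices."""
    values, col_idx, row_ptr = [], [], [0]
    for row in matrix:
        for j, v in enumerate(row):
            if v != 0.0:
                values.append(v)
                col_idx.append(j)
        row_ptr.append(len(values))
    return values, col_idx, row_ptr


def csr_matvec(values, col_idx, row_ptr, x):
    """Multiply a CSR matrix by vector x, touching only non-zero entries.
    Returns the result and the number of multiplications performed."""
    y, mults = [], 0
    for i in range(len(row_ptr) - 1):
        acc = 0.0
        for k in range(row_ptr[i], row_ptr[i + 1]):
            acc += values[k] * x[col_idx[k]]
            mults += 1
        y.append(acc)
    return y, mults


W = [[0.0, 2.0, 0.0, 0.0],
     [1.0, 0.0, 0.0, -3.0],
     [0.0, 0.0, 0.0, 0.0]]
x = [1.0, 2.0, 3.0, 4.0]
y, mults = csr_matvec(*to_csr(W), x)
print(y, mults)  # [4.0, -11.0, 0.0] 3
```

A dense product would perform 12 multiplications here; the sparse version performs 3, one per surviving weight.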
Results across several models
In our experiments, we have validated that this technique can be successfully applied to different types of models across distinct tasks, from convolutional neural networks for image processing to recurrent neural networks for speech processing. The following table shows a subset of these experimental results.
How does it work?
Our Keras-based weight pruning API uses a straightforward yet broadly applicable algorithm that iteratively removes connections based on their magnitude during training. Fundamentally, you specify a final target sparsity (e.g., 90%), a schedule for performing the pruning (e.g., start pruning at step 2,000, stop at step 10,000, and prune every 100 steps), and an optional configuration for the pruning structure (e.g., apply it to individual values or to blocks of values of a certain shape).
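In the released API these settings are passed through schedule objects such as `tfmot.sparsity.keras.PolynomialDecay`; please check the documentation for their exact semantics. The class below is a hypothetical, plain-Python stand-in that only shows how a begin step, end step, and frequency decide when the pruning routine runs. Treating both ends of the window as inclusive is our assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PruningSchedule:
    """Hypothetical container for the pruning settings described above."""

    final_sparsity: float = 0.9
    begin_step: int = 2000
    end_step: int = 10000
    frequency: int = 100

    def should_prune(self, step: int) -> bool:
        """True if the pruning routine is scheduled to run at `step`.

        The window [begin_step, end_step] is treated as inclusive (an assumption).
        """
        if not self.begin_step <= step <= self.end_step:
            return False
        return (step - self.begin_step) % self.frequency == 0


schedule = PruningSchedule()
pruning_steps = [s for s in range(12_001) if schedule.should_prune(s)]
print(len(pruning_steps), pruning_steps[:3], pruning_steps[-1])  # 81 [2000, 2100, 2200] 10000
```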
As training proceeds, the pruning routine executes on schedule, eliminating (i.e., setting to zero) the weights with the lowest magnitudes (i.e., those closest to zero) until the current sparsity target is reached. Each time the routine runs, the current sparsity target is recalculated: it starts at 0% and increases gradually, following a smooth ramp-up function, until it reaches the final target sparsity at the end of the pruning schedule.
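The two ingredients of this routine can be sketched in a few lines. We assume a cubic ramp-up, target = final_sparsity * (1 - (1 - p)^3), where p is the fraction of the pruning window that has elapsed; this is one common choice that rises quickly at first and flattens out near the end, not necessarily the toolkit's exact formula. Magnitude pruning is applied to a single flat list here, whereas the real API works per layer.

```python
def target_sparsity(step, final_sparsity=0.9, begin_step=2000, end_step=10000, power=3):
    """Current sparsity target: ramps smoothly from 0% at begin_step to
    final_sparsity at end_step. The cubic form is an assumed ramp-up function."""
    if step <= begin_step:
        return 0.0
    if step >= end_step:
        return final_sparsity
    progress = (step - begin_step) / (end_step - begin_step)
    return final_sparsity * (1.0 - (1.0 - progress) ** power)


def prune_by_magnitude(weights, sparsity):
    """Zero out the smallest-magnitude weights until `sparsity` of them are zero."""
    k = round(sparsity * len(weights))  # number of weights that must be zero
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    drop = set(order[:k])
    return [0.0 if i in drop else w for i, w in enumerate(weights)]


weights = [0.5, -0.01, 0.2, -0.7, 0.03, 0.0, 0.9, -0.15, 0.05, 0.4]
for step in (2000, 4000, 6000, 10000):
    s = target_sparsity(step)
    # Prune cumulatively: weights zeroed earlier have the smallest magnitude,
    # so they stay pruned as the target grows.
    weights = prune_by_magnitude(weights, s)
    zeros = sum(1 for w in weights if w == 0.0)
    print(f"step {step}: target {s:.3f}, zeros {zeros}/{len(weights)}")
```

By the end of the schedule only the largest-magnitude weight (0.9) survives, matching the 90% final target.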
Like the schedule, the ramp-up function can be tweaked as needed. For example, in certain cases it may be convenient to start pruning only after a certain step, once the model has reached some level of convergence, or to end pruning before the last training step so that the model can be fine-tuned further at the final target sparsity. For more details on these configurations, please refer to our tutorial and documentation.
At the end of the training procedure, the tensors corresponding to the “pruned” Keras layers will contain zeros according to the final sparsity target for the layer.
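To confirm this outcome after training, one can measure the fraction of zeros in each pruned layer. In a Keras program the weights would come from something like `layer.get_weights()`; in this plain-Python sketch, hypothetical layer names simply map to flat lists of already-pruned weights.

```python
def sparsity(tensor):
    """Fraction of exactly-zero entries in a flat list of weights."""
    return sum(1 for w in tensor if w == 0.0) / len(tensor)


def sparsity_report(layers):
    """Map each (hypothetical) layer name to the sparsity of its weights."""
    return {name: sparsity(w) for name, w in layers.items()}


layers = {
    "dense_1/kernel": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.42],
    "dense_2/kernel": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.3, 0.1],
}
print(sparsity_report(layers))  # {'dense_1/kernel': 0.9, 'dense_2/kernel': 0.8}
```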
New documentation and GitHub repository
As mentioned earlier, the weight pruning API will be part of a new GitHub project and repository aimed at techniques that make machine learning models more efficient to execute and/or represent. This is a great project to star if you are interested in this exciting area of machine learning or just want to have the resources to optimize your models.
Given the importance of this area, we are also creating a new sub-site under tensorflow.org/model_optimization with relevant documentation and resources. We encourage you to give this a try right away and welcome your feedback. Follow this blog for future announcements!
Acknowledgements: Raziel Alvarez, Pulkit Bhuwalka, Lawrence Chan, Alan Chiao, Tim Davis, Suyog Gupta, Jian Li, Yunlu Li, Sarah Sirajuddin, Daniel Situnayake, Suharsh Sivakumar.