Pruning with Catalyst

Catalyst Team
Oct 23, 2020

Hi! My name is Nikita. I am one of the Catalyst contributors. I want to tell you about pruning with PyTorch and Catalyst.

Do we need to go deeper?

In the past few years, state-of-the-art architectures have become more and more complex, and the number of parameters has grown exponentially. But what if all these networks are over-parameterized and more than half of the parameters don’t influence the result? Several methods can help us take advantage of that. Not long ago, I wrote a post about one of them, knowledge distillation. You can find it here.

Today I will continue this series about reducing model size with an introduction to pruning neural networks. Let’s start!

If you haven’t heard of Catalyst before, I recommend reading this post, which introduces the ideas behind the framework along with minimal examples.

What is pruning?

If the network is over-parameterized, let’s simply try to set some of its parameters to zero. Removing connections between neurons in this way is called pruning. The idea comes from biology: the human brain is also over-parameterized in the early stages of development, and we learn by pruning unnecessary connections.

But what can we do with this theory in practice? When it comes to neural networks, connections between neurons can be represented as a matrix. So the result of applying one layer is

$$ y = f(Wx) $$

Where “f” is a non-linear activation function, for example, ReLU.

We often have a bias term there, but let’s focus on the values of the matrix elements. In this case, we can represent matrix multiplication as

$$ (Wx)_i = \sum_j w_{ij} x_j $$

So we can assume that the smaller the absolute value of a weight “w”, the less influence it has on the result. Pruning the weights with the smallest absolute values is called magnitude pruning. I ran all my experiments on the MNIST dataset.
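
To make the idea concrete, here is a minimal from-scratch sketch of magnitude pruning for a single layer, written with plain PyTorch tensors rather than Catalyst (it is not the author’s experiment code). The helper `magnitude_prune`, the 4x6 random matrix and the 50% sparsity level are illustrative assumptions.

```python
import torch


def magnitude_prune(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Hypothetical helper: return a 0/1 mask that zeroes the `sparsity`
    fraction of entries with the smallest absolute value."""
    n_prune = int(round(sparsity * weight.numel()))
    mask = torch.ones_like(weight)
    if n_prune > 0:
        # Flat indices of the n_prune entries with the smallest |w|
        idx = torch.topk(weight.abs().flatten(), n_prune, largest=False).indices
        mask.view(-1)[idx] = 0.0
    return mask


torch.manual_seed(0)
W = torch.randn(4, 6)  # weight matrix of one layer (out_features x in_features)
x = torch.randn(6)     # input vector

mask = magnitude_prune(W, sparsity=0.5)
W_pruned = W * mask

y = torch.relu(W @ x)                # y = f(Wx) with f = ReLU
y_pruned = torch.relu(W_pruned @ x)  # same layer with half of the weights zeroed
print("zeroed weights:", int((W_pruned == 0).sum()), "of", W.numel())
print("change in output:", (y - y_pruned).norm().item())
```

Every kept weight has a larger magnitude than every removed one, which is exactly the ranking assumption behind magnitude pruning.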

Let’s try to prune these connections and see the results!
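
A sketch of the same idea applied to a whole network, assuming PyTorch’s built-in `torch.nn.utils.prune` utilities (the post’s own experiments use Catalyst, whose pruning API is not shown here). The MLP architecture and the 70% global sparsity are arbitrary choices for illustration, not the settings behind the reported results.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# A small MLP with MNIST-sized input (28*28) and 10 classes; the
# architecture is an assumption, not the one used in the experiments.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10),
)
linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]

# Zero the 70% of weights with the smallest |w|, ranked across all layers at once
prune.global_unstructured(
    [(layer, "weight") for layer in linear_layers],
    pruning_method=prune.L1Unstructured,
    amount=0.7,
)


def weight_sparsity(layers):
    """Fraction of weight entries that are exactly zero."""
    zeros = sum(int((layer.weight == 0).sum()) for layer in layers)
    total = sum(layer.weight.numel() for layer in layers)
    return zeros / total


print(f"global weight sparsity: {weight_sparsity(linear_layers):.2%}")
```

Because the threshold is global, layers with many small weights (typically the large first layer) end up sparser than the others.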

Even this simple method lets us cut the number of non-zero parameters roughly threefold without losing quality! But can we do better?

Iterative pruning

Instead of pruning once, we can fine-tune the network after each pruning step and repeat the cycle.
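
One way to sketch this loop, again assuming `torch.nn.utils.prune`: each round removes 20% of the still-unpruned weights in every layer and then fine-tunes for a few steps. Random tensors stand in for MNIST here, and the number of rounds, the pruning rate and the learning rate are assumptions chosen only to keep the example small.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Synthetic stand-in for MNIST (assumption): 512 random "images", 10 classes
X = torch.randn(512, 784)
y = torch.randint(0, 10, (512,))

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
layers = [m for m in model if isinstance(m, nn.Linear)]
loss_fn = nn.CrossEntropyLoss()


def fine_tune(model, steps=20, lr=0.1):
    # After pruning, the optimizer updates `weight_orig`; the mask keeps pruned entries at zero
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        opt.step()
    return loss.item()


sparsities = []
for step in range(5):
    # Repeated calls combine masks; a float `amount` applies to the still-unpruned weights
    for layer in layers:
        prune.l1_unstructured(layer, name="weight", amount=0.2)
    loss = fine_tune(model)
    zeros = sum(int((layer.weight == 0).sum()) for layer in layers)
    total = sum(layer.weight.numel() for layer in layers)
    sparsities.append(zeros / total)
    print(f"round {step}: sparsity {sparsities[-1]:.1%}, loss {loss:.3f}")
```

After five rounds roughly 1 - 0.8^5, or about 67%, of the weights are zero, while the network has had a chance to recover between rounds.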

Let’s look at the results!

Lottery Ticket Hypothesis

If our network is over-parameterized, then maybe the randomly initialized network already contains a subnetwork that could solve our task more efficiently. All we need to do is find this subnetwork. So after pruning, we can restore the initial weights but keep the pruning mask, and then train this subnetwork again.
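
Below is a minimal sketch of this rewinding procedure with `torch.nn.utils.prune`, under simplifying assumptions: synthetic data instead of MNIST, a single one-shot pruning round at 80% sparsity, and a short training loop. After pruning, PyTorch keeps the dense weights in `weight_orig` and the mask in the `weight_mask` buffer, so rewinding amounts to copying the saved initial values back into `weight_orig`.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
X = torch.randn(512, 784)         # synthetic stand-in for MNIST (assumption)
y = torch.randint(0, 10, (512,))
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
layers = [m for m in model if isinstance(m, nn.Linear)]

# 1. Remember the random initialization
init_state = copy.deepcopy(model.state_dict())


def train(model, steps=50, lr=0.1):
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        nn.functional.cross_entropy(model(X), y).backward()
        opt.step()


# 2. Train the dense network, then prune it by weight magnitude
train(model)
for layer in layers:
    prune.l1_unstructured(layer, name="weight", amount=0.8)
masks = [layer.weight_mask.clone() for layer in layers]

# 3. Rewind: restore the initial weights but keep the pruning masks
with torch.no_grad():
    for i, layer in enumerate(model):
        if isinstance(layer, nn.Linear):
            layer.weight_orig.copy_(init_state[f"{i}.weight"])
            layer.bias.copy_(init_state[f"{i}.bias"])

# 4. Train the sparse subnetwork (the "lottery ticket" candidate) from the rewound init
train(model)
```

Masked entries receive zero gradients during step 4, so the mask is preserved and only the surviving connections are trained.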

The results are very close to those of iterative pruning. This is not very intuitive: a subnetwork reset to its random initialization performs as well as the iteratively pruned and fine-tuned network. The most interesting implication is that if we could somehow guess the lottery ticket mask right from the start, we could reach the same result in a single iteration!

Hardware acceleration

Since sparse layers are not available in PyTorch (except for nn.Embedding), we don’t get any speedup or model size reduction: we just replace some weights with zeros. It even runs slower if we keep the pruning mask, because an additional forward pre-hook has to apply the mask before every forward pass. What can we do?
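
A small check of this point, assuming `torch.nn.utils.prune`: after pruning, the layer still stores a dense `weight_orig` parameter plus a dense `weight_mask` buffer, and `prune.remove` only makes the zeros permanent without changing the tensor’s shape. The layer size and the 90% sparsity are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(1000, 1000)
n_before = sum(p.numel() for p in layer.parameters())

# Zero 90% of the weights by magnitude
prune.l1_unstructured(layer, name="weight", amount=0.9)

# The dense weight is still stored (as `weight_orig`) next to a dense `weight_mask` buffer
n_after = sum(p.numel() for p in layer.parameters())
n_buffers = sum(b.numel() for b in layer.buffers())
was_pruned = prune.is_pruned(layer)  # a forward pre-hook re-applies the mask on every call
print("parameters before/after:", n_before, n_after, "| mask entries:", n_buffers)

# Making the pruning permanent removes the mask and the hook,
# but `weight` is still an ordinary dense tensor that is mostly zeros
prune.remove(layer, "weight")
print(layer.weight.shape, "zeros:", int((layer.weight == 0).sum()))
```

While the mask is attached, memory use actually grows, since the mask is as large as the weight itself.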

Structured pruning

But wait: we can also prune whole neurons instead of individual connections. What does that mean in practice?

Every neuron is represented by a row in the weight matrix, so we need to apply pruning over the first dimension (dim=0). After pruning, we can remove entire rows and therefore reduce the number of operations!
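
With `torch.nn.utils.prune`, pruning whole neurons of a linear layer can be sketched with `ln_structured` over `dim=0`, ranking rows by their L2 norm (`n=2`). The layer size and the 50% amount are illustrative; note that this only zeroes rows, it does not shrink the tensor.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(20, 8)  # weight shape (8, 20): one row per output neuron

# Zero the 50% of rows (neurons) with the smallest L2 norm: n=2, dim=0
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

zero_rows = (layer.weight == 0).all(dim=1)
print("pruned neurons:", zero_rows.nonzero().flatten().tolist())
# Only `weight` is pruned here; the bias entries of these neurons are left untouched
```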

However, there is a catch. If we prune several neurons, the layer’s output shape changes. In other words, if we remove a row from the weight matrix of, say, layer 1, we must also remove the corresponding column from the weight matrix of layer 2.
For convolutional layers, we can reduce the number of output channels by pruning the weight tensors over the first dimension. But remember: if you want to actually speed up your model, you have to remove the zeroed rows, columns, or channels manually.
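
Removing the zeroed neurons has to be done by hand. Here is a sketch for two consecutive linear layers; `shrink_linear_pair` is a hypothetical helper, not a PyTorch or Catalyst API. For simplicity, the biases of the pruned neurons are zeroed too, so that after ReLU those neurons contribute nothing to layer 2 and the smaller network computes the same outputs. For convolutions the same indexing would apply to output channels (`weight[keep]`) and to the next layer’s input channels (`weight[:, keep]`).

```python
import torch
import torch.nn as nn


def shrink_linear_pair(fc1: nn.Linear, fc2: nn.Linear, keep: torch.Tensor):
    """Hypothetical helper: keep only the output neurons of `fc1` listed in
    `keep`, and the matching input columns of `fc2`."""
    new_fc1 = nn.Linear(fc1.in_features, len(keep), bias=fc1.bias is not None)
    new_fc2 = nn.Linear(len(keep), fc2.out_features, bias=fc2.bias is not None)
    with torch.no_grad():
        new_fc1.weight.copy_(fc1.weight[keep])      # surviving rows of layer 1
        if fc1.bias is not None:
            new_fc1.bias.copy_(fc1.bias[keep])      # and their bias entries
        new_fc2.weight.copy_(fc2.weight[:, keep])   # matching columns of layer 2
        if fc2.bias is not None:
            new_fc2.bias.copy_(fc2.bias)
    return new_fc1, new_fc2


torch.manual_seed(0)
fc1, fc2 = nn.Linear(20, 8), nn.Linear(8, 5)
# Simulate structured pruning: zero 4 neurons of fc1 (rows and biases)
with torch.no_grad():
    fc1.weight[[1, 3, 4, 6]] = 0.0
    fc1.bias[[1, 3, 4, 6]] = 0.0

keep = (fc1.weight != 0).any(dim=1).nonzero().flatten()
small1, small2 = shrink_linear_pair(fc1, fc2, keep)

x = torch.randn(3, 20)
dense_out = fc2(torch.relu(fc1(x)))
small_out = small2(torch.relu(small1(x)))
print("max difference:", (dense_out - small_out).abs().max().item())
```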

Do I need to try pruning in my case?

Imagine you have a big, slow model and you are trying to speed it up. What should you try first? The answer depends on the task, but pruning is not the first thing to try. Here is what I would try, in this order:

1. TorchScript

When you convert your model to TorchScript, you can run it from languages that are faster than Python, for example C++. This works in almost every case. You can also export your model to formats like ONNX. But what if that is not enough?
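
A minimal sketch of this step, assuming a small stand-in MLP: compile it with `torch.jit.script`, save it, reload it and check that the outputs match. The file name is arbitrary.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).eval()
example = torch.randn(1, 784)

# Compile the model to TorchScript (torch.jit.trace(model, example) also works here)
scripted = torch.jit.script(model)

# The saved archive can be loaded without Python, e.g. with torch::jit::load in C++
path = os.path.join(tempfile.mkdtemp(), "model_scripted.pt")
scripted.save(path)

loaded = torch.jit.load(path)
with torch.no_grad():
    eager_out = model(example)
    script_out = loaded(example)
print("max difference:", (eager_out - script_out).abs().max().item())
```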

2. Quantization

Quantization is currently a beta feature in PyTorch, but it works well in almost every case. For example, here is a tutorial with BERT. Don’t forget to apply the first step (TorchScript) after quantization.
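
For example, dynamic quantization of the linear layers can be sketched as follows, using a toy MLP rather than the BERT model from the tutorial. `torch.ao.quantization.quantize_dynamic` stores `nn.Linear` weights as int8 and quantizes activations on the fly; in older PyTorch releases the same function lives at `torch.quantization.quantize_dynamic`. Outputs differ slightly from the float model, so it is worth checking quality on your own data.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

# Replace nn.Linear modules with dynamically quantized (int8 weight) versions;
# the original float model is left unchanged
qmodel = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(4, 784)
with torch.no_grad():
    diff = (model(x) - qmodel(x)).abs().max().item()
print(qmodel)
print("max abs difference vs. float model:", diff)
```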

3. Try a different architecture

For example, if you have enough data for text classification, you may be able to replace your transformer with logistic regression on top of tf-idf features while keeping almost the same quality. Check out Yury Kashnitsky's talk about such a case.

4. KD + Quantization

If you have enough time, you can try knowledge distillation (KD): transferring knowledge from the big model to a smaller student model. However, it takes more effort, and there may be no ready-made pipeline for your particular case.
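
The common distillation objective (soft targets at temperature T plus the usual hard-label loss, following Hinton et al.’s knowledge distillation paper) can be sketched as below. The function name and the default `T` and `alpha` values are assumptions, not settings from the earlier post.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.5):
    """Soft-target KL divergence at temperature T plus hard-label cross-entropy.
    T and alpha are illustrative defaults, not tuned values."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # rescale so the soft-target gradients keep a similar magnitude across T
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1 - alpha) * hard


# Usage with toy stand-ins for the big teacher and small student
teacher = torch.nn.Linear(784, 10)
student = torch.nn.Linear(784, 10)
x, y = torch.randn(16, 784), torch.randint(0, 10, (16,))
with torch.no_grad():
    t_logits = teacher(x)  # the teacher is not updated
loss = distillation_loss(student(x), t_logits, y)
loss.backward()
```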

5. Pruning + KD + Quantization

Pruning comes last. For example, you can prune your network and then fine-tune it with KD losses from the full model until it reaches the quality you need.
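
A sketch of this combination, assuming synthetic data and a randomly initialized stand-in for the trained full model: the student starts as a pruned copy of the full model and is fine-tuned with a KD loss against the frozen original. In a real setup the teacher would be your trained model and you would track validation quality rather than training loss.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)
X = torch.randn(512, 784)          # synthetic stand-in data (assumption)
y = torch.randint(0, 10, (512,))

teacher = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).eval()
# The student starts as a copy of the full model and is then pruned
student = copy.deepcopy(teacher).train()
for module in student:
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.8)


def kd_loss(s_logits, t_logits, targets, T=4.0, alpha=0.5):
    soft = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                    F.softmax(t_logits / T, dim=-1),
                    reduction="batchmean") * T * T
    return alpha * soft + (1 - alpha) * F.cross_entropy(s_logits, targets)


opt = torch.optim.SGD(student.parameters(), lr=0.1)
with torch.no_grad():
    t_logits = teacher(X)          # the full model stays frozen
losses = []
for _ in range(30):
    opt.zero_grad()
    loss = kd_loss(student(X), t_logits, y)
    loss.backward()
    opt.step()
    losses.append(loss.item())
print(f"KD loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```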

That being said, pruning is not common in production today, but it is one of the most interesting research areas. Someday there may be sparse layers in PyTorch, and we can already see that the newest GPUs handle sparse operations more efficiently.

Code for all the experiments above is available here.

Thank you!
If you have any questions, feel free to join our Catalyst community Slack ;)

PyTorch

An open source machine learning framework that accelerates…
