Sparse Weight Activation Training: Reduce memory and training time in Machine Learning
Sparsity is one of the next frontiers in Deep Learning. Don’t sleep on it.
A little while ago, I covered Google AI’s Pathways architecture, calling it a revolution in Machine Learning. One of the standouts in Google’s novel approach was the use of sparse activation in their training architecture. I liked this idea so much that I decided to explore it in a lot more depth. That’s where I came across Sparse Weight Activation Training (SWAT), by researchers at the Department of Electrical and Computer Engineering, University of British Columbia. And the paper definitely has me excited.
For ResNet-50 on ImageNet SWAT reduces total floating-point operations (FLOPS) during training by 80% resulting in a 3.3× training speedup when run on a simulated sparse learning accelerator representative of emerging platforms while incurring only 1.63% reduction in validation accuracy. Moreover, SWAT reduces memory footprint during the backward pass by 23% to 50% for activations and 50% to 90% for weights.
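To put those numbers in perspective, here is some back-of-the-envelope arithmetic. It assumes, as a simplification, that runtime scales perfectly with FLOPs, in which case removing 80% of the FLOPs would give at most a 5x speedup. The 3.3x in the abstract comes from a simulated accelerator, and real hardware rarely turns FLOP savings into speedup one-for-one.

```python
# Back-of-the-envelope arithmetic on the numbers quoted above.
# Simplifying assumption: runtime is proportional to FLOPs, so removing a
# fraction `flop_reduction` of the work gives at most 1 / (1 - flop_reduction).


def ideal_speedup(flop_reduction: float) -> float:
    """Upper bound on speedup if runtime scaled perfectly with FLOPs."""
    return 1.0 / (1.0 - flop_reduction)


reported_flop_reduction = 0.80  # abstract: 80% fewer training FLOPs
reported_speedup = 3.3          # abstract: 3.3x on a simulated sparse accelerator

bound = ideal_speedup(reported_flop_reduction)
print(f"ideal speedup: {bound:.1f}x")
print(f"reported speedup as a fraction of the ideal: {reported_speedup / bound:.0%}")
```

In other words, the simulated accelerator captures roughly two thirds of the theoretical gain.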
In this article, I will be sharing some key insights from this paper. As ML operations scale up, algorithms that exploit sparsity will be at the cutting edge. You definitely don’t want to miss out.
Understanding the need for Sparse Activation
To understand why Sparse Activation and SWAT are so cool, think back to how Neural Networks work. When we train them, input flows through all the neurons, in both the forward and backward passes. This is why every parameter we add to a Neural Network makes every training step more expensive.
Adding more neurons to our network allows our model to learn from more complex data (like data from multiple tasks and data from multiple senses). However, this adds a lot of computational overhead.
Sparse Activation offers a best-of-both-worlds scenario. Adding a lot of parameters allows our model to learn more tasks effectively (and make deeper connections), while Sparse Activation lets each input use only part of the network. This allows the network to learn and get good at multiple tasks without being too costly.
The concept reminds me of a modern twist on the Mixture of Experts (MoE) learning protocol. Instead of deciding which expert can handle the task best, we route the task to the part of the neural network that handles it best. This is similar to our brain, where different parts are good at different things. It’s not a coincidence that MoE itself is making a small comeback in large-scale model training. Delegating tasks to smaller experts or sub-networks is an amazing way to balance scale and cost.
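For readers who haven’t seen MoE routing before, the sketch below shows the general idea of top-k gating: a small router scores every expert, and each input is sent only to its highest-scoring expert(s). This is a generic illustration, not code from Pathways or the SWAT paper; the router weights, the experts (plain random linear maps), and all sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_experts, k = 8, 4, 1  # hypothetical sizes; k = experts used per input

# Hypothetical router and experts: each expert is just a random linear map here.
gate_w = rng.normal(size=(d_model, n_experts))
experts = [rng.normal(size=(d_model, d_model)) for _ in range(n_experts)]


def moe_forward(x):
    """Send each row of x through only its top-k experts (sparse activation)."""
    logits = x @ gate_w                          # (batch, n_experts) router scores
    chosen = np.argsort(logits, axis=1)[:, -k:]  # k highest-scoring experts per input
    out = np.zeros_like(x)
    for i, row in enumerate(x):
        scores = logits[i, chosen[i]]
        gates = np.exp(scores - scores.max())
        gates /= gates.sum()                     # softmax over the chosen experts only
        for gate, e in zip(gates, chosen[i]):
            out[i] += gate * (row @ experts[e])  # unchosen experts do no work at all
    return out, chosen


x = rng.normal(size=(5, d_model))
y, routes = moe_forward(x)
print(routes.ravel())  # which expert each of the 5 inputs was routed to
```

With `k=1`, each input pays for one expert instead of four, which is exactly the kind of saving sparse activation is after.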
Now that you’re sold on the amazing world of sparsity, let’s dive into SWAT and what it does differently.
Breaking down SWAT
As the name suggests, SWAT sparsifies the weights and activations of the network. The process is relatively intuitive. It assumes that the values with the biggest magnitudes are the most important. In the spirit of the 80–20 principle, we keep only these important values and set the other, less influential values to 0, eliminating them.
This is not a very hard algorithm to conceptualize. However, there are a few design choices to make when applying such sparsification. We can choose to drop weights, activations, gradients (calculated during backpropagation), or some combination of them. The SWAT team conducted a sensitivity analysis, checking how convergence was affected by each of them.
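A minimal sketch of magnitude-based Top-K sparsification in NumPy, assuming a simple unstructured version where the `density` fraction of entries with the largest absolute values is kept across the whole tensor. The function name and the `density` parameter are hypothetical, and the paper’s exact selection scheme may differ; at 80% sparsity you would use `density=0.2`.

```python
import numpy as np


def top_k_sparsify(t: np.ndarray, density: float) -> np.ndarray:
    """Keep the `density` fraction of entries with the largest |value|; zero the rest."""
    flat = np.abs(t).ravel()
    k = max(1, int(round(density * flat.size)))  # number of values to keep
    # argpartition finds the k largest magnitudes without a full sort
    keep = np.argpartition(flat, -k)[-k:]
    mask = np.zeros(flat.size, dtype=bool)
    mask[keep] = True
    return np.where(mask.reshape(t.shape), t, 0.0)


w = np.array([[0.9, -0.05, 0.3],
              [-1.2, 0.01, 0.2]])
# density=0.5 keeps 3 of the 6 entries: 0.9, 0.3 and -1.2
print(top_k_sparsify(w, density=0.5))
```

Note that the sign is kept: -1.2 survives because its magnitude is the largest.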
Figure 2 is an interesting one. The difference between dropping gradients and dropping weights+activations is clear: the former wrecks your performance. The authors themselves point out this phenomenon: ‘The “sparse weight and activation” curve shows that convergence is relatively insensitive to applying Top-K sparsification. In contrast, the “sparse output gradient” curve shows that convergence is sensitive to applying Top-K sparsification to back-propagated error gradients […]. The latter observation indicates that meProp, which drops back-propagated error-gradients, will suffer convergence issues on larger networks.’
This gives us a good starting point: ‘In the forward pass use sparse weights (but not activations) and in the backward pass use sparse weights and activations (but not gradients).’ And that, my lovely reader, is the basis of SWAT, explained simply.
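Here is a sketch of that rule applied to a single fully connected layer `y = x @ w`, with the gradients written out by hand. It is an illustration under simplifying assumptions (one linear layer, unstructured Top-K, a made-up `density` of 0.2, hypothetical function names), not the authors’ implementation.

```python
import numpy as np


def top_k_sparsify(t, density):
    """Keep the largest-magnitude `density` fraction of entries; zero the rest."""
    flat = np.abs(t).ravel()
    k = max(1, int(round(density * flat.size)))
    keep = np.argpartition(flat, -k)[-k:]
    mask = np.zeros(flat.size, dtype=bool)
    mask[keep] = True
    return np.where(mask.reshape(t.shape), t, 0.0)


def swat_linear_step(x, w, grad_out, density=0.2):
    """Forward and backward pass of y = x @ w following the SWAT rule (sketch).

    x: (batch, in) activations, w: (in, out) weights,
    grad_out: (batch, out) gradient of the loss with respect to y.
    """
    # Forward pass: sparse weights, dense activations.
    w_sparse = top_k_sparsify(w, density)
    y = x @ w_sparse

    # Backward pass: sparse weights and sparse activations, dense gradients.
    # In a real framework, x_sparse is what would be stashed for the backward
    # pass, which is where activation-memory savings would come from.
    x_sparse = top_k_sparsify(x, density)
    grad_x = grad_out @ w_sparse.T  # gradient w.r.t. the layer input
    grad_w = x_sparse.T @ grad_out  # gradient w.r.t. ALL weights, pruned ones included
    # grad_out is used as-is: error gradients are never sparsified.
    return y, grad_x, grad_w


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 10))
w = rng.normal(size=(10, 3))
g = rng.normal(size=(4, 3))
y, grad_x, grad_w = swat_linear_step(x, w, g)
print(np.count_nonzero(top_k_sparsify(w, 0.2)), "of", w.size, "weights used in the forward pass")
```

Note that `grad_w` is not masked by the weight mask, so weights that were dropped in this step still receive gradient updates. That detail comes back when we get to dead weights.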
There is another idea that I found extremely important. And that is how this team resurrects the dead. Yes, I’m completely serious.
Using Zombies and reviving Dead Neurons
During both forward and back propagation, only the most important weights are used for calculations. Most people would just stop here, and ignore the dead neurons for training. This is how most network pruning happens. However, these researchers are also woke.
They update both the active and the dead weights with the dense gradient (remember, we have already established that gradients should not be dropped). This adds a comeback mechanism of sorts for previously dead weights: just because a weight is dead in one iteration doesn’t mean it won’t come back in another. This allows the network to explore network topologies (structures) dynamically.
This allows the algorithm to perform a beautiful balancing act. Pruning and dropout can be used to stop overfitting, improve generalization, and reduce the cost of training. However, reducing connectivity is tricky and tends to increase training loss, especially if done wrong. This approach has a similar effect to removing connections and neurons, but it keeps updating which ones are removed to find the best configuration. The paper includes a full description of the algorithm.
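A toy illustration of the comeback mechanism, assuming a made-up 4-weight layer and a fixed, hypothetical gradient that keeps pushing the smallest weight upward (in real training the gradient comes from the backward pass and changes every step). The same weights are trained twice: once with SWAT-style dense updates and once with prune-style masked updates.

```python
import numpy as np


def top_k_mask(w, k):
    """Boolean mask marking the k largest-magnitude entries of a 1-D weight vector."""
    mask = np.zeros(w.size, dtype=bool)
    mask[np.argpartition(np.abs(w), -k)[-k:]] = True
    return mask


# Hypothetical toy setup: 4 weights, keep the top 2 (50% sparsity).
w0 = np.array([1.0, 0.6, 0.5, 0.1])
k, lr = 2, 0.1
# Hypothetical gradient that keeps pushing the smallest weight, w[3], upward.
grad = np.array([0.0, 0.0, 0.0, -3.0])

w_dense, w_masked = w0.copy(), w0.copy()
dense_history, masked_history = [], []
for step in range(4):
    dense_history.append(np.flatnonzero(top_k_mask(w_dense, k)).tolist())
    alive = top_k_mask(w_masked, k)
    masked_history.append(np.flatnonzero(alive).tolist())
    w_dense -= lr * grad           # SWAT-style: the dense update reaches dead weights too
    w_masked -= lr * grad * alive  # prune-style: dead weights are frozen

print("dense updates: ", dense_history)   # active indices at each step
print("masked updates:", masked_history)
```

With dense updates, `w[3]` grows until it overtakes `w[1]` and rejoins the active set at step 2, while `w[1]` drops out. With masked updates, `w[3]` stays frozen at its initial value and never comes back.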
Now for the final bit, let’s evaluate the results of SWAT on a bunch of tasks. Speedups and memory efficiency are useless if not backed by great performance.
SWAT on tasks
The first graphic to look at compares the drop in accuracy against the reduction in training time/cost. If SWAT can reduce costs while keeping reasonable performance, it gets a pass. SWAT is compared to its competitors at two sparsity levels: 80% and 90%.
The performance is quite impressive. Given that most of the baselines are already very good, the slight drop in accuracy at 80% (or even 90%) sparsity is not a huge concern. The 8x reduction from SWAT-U at 90% sparsity is also pretty exciting and makes a case for exploring this algorithm further. Next, let’s look at some raw numbers. The authors provide their own analysis of these results, and for those of you who are curious, the tables in the paper give the full figures.
These results are quite spectacular. However, there is a lot left to explore. I would’ve liked to see more comparisons across multiple kinds of tasks and policies. Given the utility of Data Augmentation in vision tasks these days, it would be interesting to see what role sparsity plays there. What about tasks like generation, segmentation, etc.? I think there are many areas where we can test out SWAT.
That being said, this paper is an amazing first step. The authors have established a pretty exciting algorithm, and I will definitely be looking into this further. If you have any experience with SWAT or other sparsity-oriented algorithms/procedures, share them with me. I’m definitely looking to learn a lot more.
That’s it for this article. If you’re looking to get into ML, I have another article that gives you a step-by-step plan to develop proficiency in Machine Learning. It uses FREE resources. Unlike other boot camps/courses, this plan will help you develop your foundational skills and set yourself up for long-term success in the field.
For Machine Learning, a base combining Software Engineering, Math, and Computer Science is crucial. It will help you conceptualize, build, and optimize your ML. My daily newsletter, Coding Interviews Made Simple, covers topics in Algorithm Design, Math, Recent Events in Tech, Software Engineering, and much more to make you a better developer. I am currently running a 20% discount for a WHOLE YEAR, so make sure to check it out.
I created Coding Interviews Made Simple using new techniques discovered through tutoring multiple people into top tech firms. The newsletter is designed to help you succeed, saving you from hours wasted on the Leetcode grind. I have a 100% satisfaction policy, so you can try it out at no risk to you. You can read the FAQs and find out more here
Feel free to reach out if you have any interesting jobs/projects/ideas for me as well. Always happy to hear you out.
Reach out to me
Use the links below to check out my other content, learn more about tutoring, or just to say hi. Also, check out the free Robinhood referral link. We both get a free stock (you don’t have to put any money), and there is no risk to you. So not using it is just losing free money.
Check out my other articles on Medium: https://rb.gy/zn1aiu
My YouTube: https://rb.gy/88iwdd
Reach out to me on LinkedIn. Let’s connect: https://rb.gy/m5ok2y
My Instagram: https://rb.gy/gmvuy9
My Twitter: https://twitter.com/Machine01776819
If you’re preparing for coding/technical interviews: https://codinginterviewsmadesimple.substack.com/
Get a free stock on Robinhood: https://join.robinhood.com/fnud75