Training Compact Transformers from Scratch in 30 Minutes with PyTorch

Steven Walton · Published in PyTorch · Jun 28, 2021 · 20 min read

Authors: Steven Walton, Ali Hassani, Abulikemu Abuduweili, and Humphrey Shi. SHI Lab @ University of Oregon and Picsart AI Research (PAIR)

In this tutorial we’ll introduce Compact Transformers: compute- and data-efficient transformers that the average person can train on their home computer (quickly) and get state-of-the-art results for classification in both computer vision and NLP. If you’ve wanted to learn how to program and train transformers but don’t have a bunch of fancy GPUs, or if the task at hand requires you to work with smaller datasets, then this is the tutorial for you. You can also find our code and models on GitHub.

Diagram of Compact Transformers

Transformers: What Do They Know? Do They Know Things?? Let’s Find Out!

Transformers are the popular new kid on the block in machine learning. But who really are they and why are they here? Are they just some popular new tool or are they really a breakthrough? Are these models restricted to big labs with lots of computing resources and data or can they be useful to mere mortals like you and me?

Through this tutorial we will break down every part of the above diagram and figure out how to create our own Vision Transformers. We will then discuss how to create a more democratic model, which is significantly smaller and will allow you to train and test on your own machine, even if you don’t have the newest hardware or even if you don’t have a GPU. We’ll start with the Transformer Encoder and with the most difficult and important part first: Attention. So…

May I Have Your Attention Please?: The Attention Mechanism

When humans look at a scene, they focus their field of view on only what is important in the scene. This is helpful because then we only need to process information about what is important and can ignore the rest. It would be advantageous if we could do something similar with machine learning.

Figure 1: Various kinds of attention

That “something similar” is called attention. There are many forms of attention, but one shared aspect stands out. They take the form:
attention = similarity(q,k)
Here q represents a query (a question) and k represents a key. We want similarity between these two things because we want to find the key that matches closest to our question/query. You can think about this as accessing a database. We have something we want to find, so we query the database looking for the information we want.

Looking at the different attention equations (Fig. 1), we may notice that they are all doing something similar. In this post we will focus only on the “Scaled Dot-Product Self-Attention” (Eq. 1), but we’ll break the equation down into its fundamentals, and you will be able to apply this knowledge to help you understand the other types of attention.

Attention(q, k, v) = softmax(<q, k> / √d) · v
Equation 1: Scaled Dot-Product Attention
Figure 2: Similarity of two vectors using inner product (cosine similarity)

First, let’s look at the inside: we see <q,k>. This notation means we’re taking the inner product, or dot product. There’s a geometric property to inner products that is helpful to us here: they tell us how similar two vectors are! If the vectors are completely orthogonal (perpendicular) then their inner product is 0, and if two unit vectors point in exactly the same direction then their inner product is 1. As long as we are working with unit vectors (vectors with length 1) we can say that the inner product is determined by the angle between them. If we aren’t working with unit vectors, the relative scale of the vectors matters too (look back up and see if you can find similar properties in the other forms of attention).

Okay, now let’s finish looking at the inside of the softmax function: we still have that √d term. That’s our scale. You can think of this as simply a nice way to keep the numbers from getting too big. d represents the dimensionality of the vectors, so dividing by √d keeps the dot products in a comfortable range. The softmax can produce poor gradients if its inputs grow too large, so we just want to scale everything down a bit (scaling is a common procedure in data wrangling, but we always need to be careful how we scale so we don’t induce biases).

Now let’s move outward to the softmax function. We can actually think of this as a way to turn our scores into a probability distribution: the softmax makes every element lie between 0 and 1, with a total sum of 1, just like any probability distribution.

softmax(x)ᵢ = exp(xᵢ) / Σⱼ exp(xⱼ)
Equation 2: Softmax

Why do we want this? Well, because we want to tell our algorithm what to pay attention to. If our attention score is 0 then we don’t pay attention to those parts and if our attention is 1 then we only care about that one pixel (this would be the same as a one-hot vector). Having a probability distribution makes it easier for us to do our computation because we can keep all our fancy statistics (e.g. we can calculate the likelihood).

Now we have our last component, what makes this “self-attention?” This is actually the simplest part. It means that we’re applying this to… itself. Our vectors q and k are actually neural networks (typically linear) and if they have the same input (q(x), k(x)), then they are self attending (our v also needs the same input, v(x)).

So there we have Scaled (√d) Dot-Product (<q,k>) Self (q(x), k(x),v(x)) Attention (self-similarity). The last part we have is the value, v. Everything we’ve shown above just tells us ‘’how much’’ we want to attend to certain elements. We need v to tell us “what’’ to attend to!

There’s one more analogy I want to take you through before we move on to transformers, and that is why we use this query, key, value, “nonsense.” We can actually think about this as a database where we are trying to get values. To do that we have a query that helps us find a key which has a corresponding value. This should help us better connect together why we want some similarity between q and k. We just want to learn a way to make sure our queries grab the appropriate key.

So there we have it. I hope I was able to hold your attention through all of this, now for some code.

Code 1: Self-Attention
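The original post embeds this code as a gist; below is a minimal sketch of what such a module might look like, pieced together from the description above (the class name `SelfAttention` and exact layout are illustrative assumptions rather than the authors’ exact code):

```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Scaled dot-product self-attention, single head."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5            # the 1/sqrt(d) scaling term
        self.qkv = nn.Linear(dim, dim * 3)  # one linear layer doing the work of q, k, and v

    def forward(self, x):
        # x: (batch, sequence_length, dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)         # split back into q, k, v
        attn = (q @ k.transpose(-2, -1)) * self.scale  # similarity of every pair of tokens
        attn = attn.softmax(dim=-1)                    # turn scores into a distribution
        return attn @ v                                # weigh the values by the attention
```

As a quick sanity check, `SelfAttention(64)(torch.randn(2, 16, 64))` gives back a tensor of shape `(2, 16, 64)`: one attended vector per input token.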

We’re exploiting a nice trick with linear layers here where we can just use one layer that acts the same as if we had 3 different layers, one for q, k, and v. Because there is no interdependence between layers we can think of these as modular. We’ll see a similar trick again later but with convolutions.

Vision Transformers: More Than Meets the Eye

Vaswani et al. say Attention Is All You Need, but we need a few more things for the transformer. For Vision Transformers we only need to look at the encoder, but if you’re looking to learn about transformers in full this will still help.

Figure 3: Vaswani et al’s Attention is All You Need

We’re going to start from the bottom and move up. This may look intimidating but we’ve already covered the most complex part of this network, attention. We just need a little bit more to turn this into our transformer model.

Input Embeddings are the easiest part of the network. There are many ways to do this and you’ll have to experiment a bit. This is just a way to take your data and represent it in a different way. You can do this by patching (An Image is Worth 16x16 Words), convolutions, linear networks, or something else. We’ll discuss this more later, but for now just think about this as “taking in data.’’ We will refer to these embedded data as tokens.

Positional Embedding

Positional Embedding is an extremely simple concept, but it is often glossed over in a way that can be confusing to newcomers, so let me break it down just a bit. If you don’t need this, just know that PE is a way for us to express how data lines up in a positional relationship. We need this because we’re working with a sequence of data, and if we don’t add information about how that data lies within the sequence, we are going to have a hard time learning relationships about the sequence.

Let’s quickly look at the most common form of PE, the sinusoidal embedding.

PE(pos, 2k) = sin(pos · ωₖ),  PE(pos, 2k+1) = cos(pos · ωₖ),  where ωₖ = 1 / 10000^(2k/d)
Equation 3: Positional Embedding

If you’re not familiar with waves/signals then just know that we have two waves that are out of phase (sin and cos are 90° out of phase), sampled at a set of frequencies (ωₖ). Even embedding indices are given by sin and odd indices by cos. With these frequencies and the phase offset we can convey a lot of positional information to our network. A simple illustration will show this much better than anything I can write, so here you go.

Figure 4: Positional Embedding visualization for different tokens

From this image you can see that we get a pretty expressive representation that weighs each of our positional parameters differently. Each dot represents an offset that we give to a token at the corresponding position in its embedding space. Any function that accomplishes this task will do, even a learned one. Typically we just want something with a high frequency for stability. In practice most find that there isn’t much of a difference between a learnable PE and a fixed one, so you will typically see the one above.

Code 2: Positional Embedding
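The gist isn’t reproduced here, so here is a minimal sketch of a fixed sinusoidal positional embedding following Equation 3 (the class name and construction details are assumptions; it also assumes an even embedding dimension):

```python
import math

import torch
import torch.nn as nn


class SinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding, added onto the token embeddings."""

    def __init__(self, max_length, dim):
        super().__init__()
        position = torch.arange(max_length, dtype=torch.float).unsqueeze(1)
        # one frequency omega_k = 1 / 10000^(2k/d) per pair of embedding indices
        omega = torch.exp(torch.arange(0, dim, 2, dtype=torch.float)
                          * (-math.log(10000.0) / dim))
        pe = torch.zeros(max_length, dim)
        pe[:, 0::2] = torch.sin(position * omega)    # even embedding indices -> sin
        pe[:, 1::2] = torch.cos(position * omega)    # odd embedding indices  -> cos
        self.register_buffer("pe", pe.unsqueeze(0))  # shape: (1, max_length, dim)

    def forward(self, x):
        # x: (batch, sequence_length, dim) of already-embedded tokens
        return x + self.pe[:, : x.size(1)]
```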

Multi-Headed Self-Attention

We’re back to attention, but this time we have multiple heads. We’ve already learned the difficult parts; we just need to know what multi-head means. Well, if we think about attention, we can say that it shows relations between pairs of data. If we have two heads of attention, then we can get pairs of pairs. Three? Pairs of pairs of pairs. And so on. You can see how this can be extremely expressive, but at the same time it may introduce too much complexity. One head usually does pretty well, but this is just another hyper-parameter you are going to have to search over. So let’s modify our code above to make it multi-headed attention. The only real differences are dividing the embedding dimension by the number of heads and changing the scaling factor accordingly. But modern attention modules, like the one you’ll find in PyTorch’s library, are a bit more complex than this. So let’s make those modifications while we’re at it.

Code 3: Multi-Headed Self-Attention (MSA)
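Again, the original gist isn’t reproduced, so here is a sketch of what a multi-headed version with the qkv bias, dropouts, and projection layer described in the next paragraph might look like (names and default values are assumptions):

```python
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention split across several heads."""

    def __init__(self, dim, num_heads=2, qkv_bias=False, attn_drop=0.1, proj_drop=0.1):
        super().__init__()
        assert dim % num_heads == 0, "embedding dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5                 # scale by the per-head dimension
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # fused q, k, v projection
        self.attn_drop = nn.Dropout(attn_drop)             # dropout on the attention map
        self.proj = nn.Linear(dim, dim)                    # mixes the heads back together
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)            # each: (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B, heads, N, N)
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B, N, C) # concatenate the heads
        return self.proj_drop(self.proj(x))
```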

We can see that we added a few things here. The bias becomes a parameter we can toggle for some extra expressiveness. Two important additions are the dropouts and the projection layer. Dropout is an important tool for preventing overfitting and improving the generalization of networks, and introducing it here can have similar effects. We use the projection layer to project all our pairs of pairs of … of pairs back into our embedding dimension, allowing our network to weigh these features. This is helpful because, as you might imagine, not all n-pairs are equally useful.

Encoder

Now that we’ve got the complex part out of the way, let’s build the rest of the transformer encoder. Really, all we need to do is put some pieces together and add some residual connections!

Code 4: Encoding Layer
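As a sketch (reusing the MultiHeadSelfAttention module from the Code 3 sketch, and assuming the positional encoding is added before the stack), a post-norm encoder layer in the spirit of Figure 3 might look like this:

```python
import torch.nn as nn


class TransformerEncoderLayer(nn.Module):
    """One post-norm encoder layer: MSA + residual + norm, then feed-forward + residual + norm."""

    def __init__(self, dim, num_heads=2, mlp_ratio=2, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(dim, num_heads)  # sketch from Code 3
        self.norm1 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(                           # the feed-forward block
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mlp_ratio, dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm1(x + self.attn(x))  # attention -> add (residual) -> normalize
        x = self.norm2(x + self.mlp(x))   # feed-forward -> add (residual) -> normalize
        return x
```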

The forward section has been split to look like the Transformer Encoder image from above. That’s it! We created a transformer encoder layer! You can also add dropout here. From the code you can see the flow: we take in our data (with positional encoding added), send it through an MHA layer with a residual connection, add and normalize, send it through a feed-forward network, add the residual again, and normalize once more. Repeat this encoding layer as many times as you want, but watch out, because the number of parameters grows very quickly.

We should only have two questions left: “why the residual connections?” and “why the normalization layers?” Residual connections are a common deep learning technique that allows us to build deeper networks in a more stable manner (see ResNet). Residual layers are easier to optimize, so be sure to include them if you are using deep neural networks. Layer normalization rescales our activations to a mean of 0 and a variance of 1; it is another technique that helps stabilize our network and increases our training speed.

16x16 Words: A Vision Transformer

Vision Transformers (ViT) are a class of transformers applied to vision problems. Dosovitskiy et al introduced “An Image is Worth 16x16 Words” where they showed the first image classifier composed entirely of transformers. But we’re experts on transformers now, so the question is: “how do we repeat their experiments?” (an important part of both research and learning is replicating results). Let’s look at the image they include.

Figure 5: Dosovitskiy et al’s Vision Transformer (ViT)

If you’re looking at this image and saying “wait! I can do this!” then congratulate yourself! If you don’t, then that’s okay, we’re going to talk about it anyways. What I’m not going to talk about is that big gray box labeled “Transformer Encoder” because we just finished talking about it ^_^. Do note that ViT places a normalization before the MHA and MLP, so change the encoder accordingly.

The only thing left to discuss is patching. The title of the paper gets us half way there. “An Image is Worth 16x16 Words.” So let’s break the image into 16x16 blocks, with 3 channels of course. That’s pretty easy and we know how to do this with some basic python skills. Then we just need to flatten the image patches and send them through an embedding layer (standard for transformer models). So we’re done!

Each patch represents an element in the sequence (or “word”), and is sent through the embedding layer, which merely maps it from its original pixel space (16×16×3) to a d-dimensional space.

Figure 6: Patching Animated

Well, if we think back to how convolutions work, we can apply a trick here! A convolution is a sliding window that applies a kernel operation to that window of data. So why don’t we just make a window the size of our patch, slide it so there’s no overlap, and make the kernel effectively a no-op on the image? A convolution with a stride and kernel size of 16 and d filters (a d×(16×16×3) weight) would be equivalent to the patching layer. Boom! Patching with convolutions! We can write this simply as:

Code 5: Patching
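A sketch of that convolutional patching layer (the name `PatchEmbedding` and the default embedding size are assumptions):

```python
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Split an image into patches with a strided convolution; the kernel doubles as the embedding."""

    def __init__(self, in_channels=3, embedding_dim=256, patch_size=16):
        super().__init__()
        # kernel size == stride -> non-overlapping patches, each mapped to embedding_dim
        self.proj = nn.Conv2d(in_channels, embedding_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (batch, channels, height, width)
        x = self.proj(x)                     # (batch, embedding_dim, H/patch, W/patch)
        return x.flatten(2).transpose(1, 2)  # (batch, num_patches, embedding_dim)
```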

Now, we didn’t actually make the kernel a no-op; we let it learn, which means we can use it as our embedding layer too. Two birds with one stone! So all we do is flatten the output and there we go. Now let’s make our vision transformer!

Code 6: Vision Transformer
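The full gist isn’t reproduced here; the sketch below composes the pieces from the Code 2, 4, and 5 sketches above into a small ViT-style classifier, plus one extra learnable token that the next paragraph explains (all names and defaults are assumptions). Two simplifications to note: it reuses the fixed sinusoidal embedding even though ViT learns its positional embedding, and it reuses the post-norm encoder layer even though ViT places the normalization before the MHA and MLP.

```python
import torch
import torch.nn as nn


class ViTLite(nn.Module):
    """Patch embedding + class token + positional embedding + encoder stack + MLP head."""

    def __init__(self, image_size=224, patch_size=16, in_channels=3,
                 embedding_dim=256, num_layers=7, num_heads=4, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(in_channels, embedding_dim, patch_size)  # Code 5 sketch
        self.class_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))          # learnable class token
        self.pos_embed = SinusoidalEmbedding(num_patches + 1, embedding_dim)       # Code 2 sketch
        self.encoder = nn.Sequential(*[TransformerEncoderLayer(embedding_dim, num_heads)  # Code 4 sketch
                                       for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embedding_dim)
        self.head = nn.Linear(embedding_dim, num_classes)                          # the MLP head

    def forward(self, x):
        x = self.patch_embed(x)                           # (batch, num_patches, dim)
        cls = self.class_token.expand(x.size(0), -1, -1)  # one class token per image
        x = torch.cat([cls, x], dim=1)                    # prepend it to the sequence
        x = self.pos_embed(x)                             # add positional information
        x = self.encoder(x)
        return self.head(self.norm(x[:, 0]))              # classify from the class token's output
```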

Okay, I was a bit sneaky there and threw in something we didn’t cover. If you look back at the ViT picture you’ll notice something tricky too. Under the “Patch + Position Embedding” there is a small asterisk, and the MLP Head comes out of the left side rather than the middle of the Transformer Encoder. That’s actually telling us some information. ViT, like many transformers, uses a class token as an extra learnable parameter (which we accomplish with an nn.Parameter), and the MLP Head sits at the left to indicate that we take the data corresponding to that class token (also in that leftmost position). Finally, we can train this using nn.CrossEntropyLoss.
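As a usage sketch, training then reduces to the usual PyTorch classification loop. The dataset choice, optimizer, and hyperparameters below are illustrative assumptions, and `ViTLite` is the sketch from Code 6:

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True,
                                          transform=T.ToTensor())
loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)

model = ViTLite(image_size=32, patch_size=4, num_classes=10)   # Code 6 sketch
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=3e-2)

for images, labels in loader:          # one epoch of the usual classification loop
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
```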

So that’s it! We’ve created a vision transformer. You can see that this is probably a lot easier than you were thinking it was going to be!

But there’s a big problem here. We have an important line hiding at the end of the introduction of the paper:

Transformers lack some of the inductive biases inherent to CNNs, such as translation equivariance and locality, and therefore do not generalize well when trained on insufficient amounts of data.

However, the picture changes if the models are trained on larger datasets (14M-300M images). We find that large scale training trumps inductive bias.

This may not be a problem for the Google Brain team, which has huge computers and lots of resources. But it is a big problem for most of us. I mean, we can’t even train on ImageNet, and that’s considered a “medium”-sized dataset! So what are we measly mortals to do? If transformers are data hungry, then what was the point of this? Academic curiosity? No! We’re going to make some changes that will overcome these problems and make vision transformers useful to you! Yes, even you without a GPU!

Vision Transformers for Mere Mortals: Compact and Efficient Transformers

Figure 7: Diagram of Compact Transformers

As we just showed, vision transformers (and transformers in general) are easy to understand and implement. However, the architecture is typically used at large scale with lots of training data (ViT, GPT, BERT, etc.). With ViT, you get much better performance by pre-training on large datasets before transferring to smaller ones (e.g. ImageNet -> CIFAR-10). But don’t fret! We have a few tricks up our sleeves to fix some of these problems.

Smaller ViTs

ViT variants introduced in the original paper (ViT-Base, ViT-Large and ViT-Huge) prove very effective when trained on JFT-300M and transferred to other sets of data, whether small (CIFAR-10) or medium-sized (ImageNet). The same models were also pre-trained on ImageNet, but achieved inferior results. Pre-training on smaller sets of data such as CIFAR-10 causes drastic over-fitting, to the point of making little to no mistakes on the training set while performing horribly on the test set.

The base ViT variant (ViT-Base) is made of 12 layers with 12 heads each. This model has roughly 86 million learnable parameters. That’s a lot! For comparison, ResNet50 has around 23 million, and ResNet18 has around 11 million. Now the over-fitting makes sense: a model with that many parameters is capable of learning so much more that it ends up memorizing the training data.

That’s why we can do something very simple to avoid that: smaller models. These are the sizes that we found useful:

Table 1: Small ViT parameters and CIFAR-10 Test Accuracy

Patch sizes were set to 4 instead of the original 16, because we’re dealing with much smaller images (32×32 instead of 224×224). This improves performance, but it may decrease (relative) speed, due to the increase in the sequence length.

Compact Transformers

Even the small ViTs have a lot of parameters for their relatively weak performance. We need some way to boost performance without dramatically increasing the size of the model. Our Compact Transformers, introduced in Escaping the Big Data Paradigm with Compact Transformers, allow us to do exactly that. With an increase of only 0.07M parameters we can increase the performance by over 17%! That makes this model small enough that you can train it (not just run it) on even old GPUs. In fact, we’ve been able to train it on a CPU (30 minutes to 80%).

To achieve this performance we don’t need to do many modifications (remember, research is incremental. Even ViT was only slight modifications to the Transformer model).

Figure 8: Comparison of the standard Vision Transformer and our Compact Transformers

From the above diagram you may notice that there are only two big changes. The first, introduced in CVT, is sequence pooling, or SeqPool for short. To do this we just need to add another linear layer (called the attention pool) to our constructor; in the forward pass we softmax this layer’s output and then matrix-multiply it with the input, x (you could even just use an nn.Parameter, but a linear layer has shown better performance). While ViT segmented out the corresponding class token location and passed that through to the MLP layers, we utilize all the information. We softmax because it essentially tells our network which parts of the transformer encoder’s outputs to pay the most attention to. This means that parameters are not wasted in our network; everything is used.

Code 7: Modified Transformer Classifier
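A minimal sketch of sequence pooling as described above (the class name and the final norm/head arrangement are assumptions):

```python
import torch
import torch.nn as nn


class SeqPoolClassifier(nn.Module):
    """Sequence pooling: learn how much each output token matters, then average accordingly."""

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        self.attention_pool = nn.Linear(embedding_dim, 1)  # one importance score per token
        self.norm = nn.LayerNorm(embedding_dim)
        self.head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        # x: (batch, sequence_length, embedding_dim) -- the transformer encoder's outputs
        x = self.norm(x)
        weights = self.attention_pool(x).softmax(dim=1)  # (batch, sequence_length, 1)
        pooled = torch.matmul(weights.transpose(1, 2), x)  # weighted average of all tokens
        return self.head(pooled.squeeze(1))                # (batch, num_classes)
```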

The next innovation, seen in the progression to CCT, is also simple. Patching, as done in ViT and CVT, has some inherent problems. Our network will have a difficult time understanding boundary information between patches. Without doing something else, we are essentially only performing attention on each patch and then determining how those patches relate; the model lacks what we might call relational inductive bias. Instead, let’s use convolutions that overlap and use a MaxPool. This should induce some relational bias into our model and allow our Transformer to learn from information that is better embedded.

Code 8: Tokenizer
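A sketch of such a convolutional tokenizer (the kernel size, pooling settings, and single-block structure are illustrative assumptions; CCT stacks one or more of these blocks):

```python
import torch.nn as nn


class ConvTokenizer(nn.Module):
    """Tokenize an image with an overlapping convolution + ReLU + max-pool block."""

    def __init__(self, in_channels=3, embedding_dim=256, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            # stride 1 with padding -> overlapping receptive fields, unlike ViT's hard patching
            nn.Conv2d(in_channels, embedding_dim, kernel_size=kernel_size,
                      stride=1, padding=kernel_size // 2, bias=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

    def forward(self, x):
        # x: (batch, channels, height, width)
        x = self.block(x)                    # (batch, embedding_dim, H/2, W/2)
        return x.flatten(2).transpose(1, 2)  # (batch, sequence_length, embedding_dim)
```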

This is an important step forward, because the overlapping convolutions are equivariant to spatial translations and provide the relational inductive bias that plain patching lacks. That’s partly the reason why CNNs have been so successful in vision. There are also a few other minor changes here and there which push the performance further, but this is the meat of the model. Here’s a summary of the different variants and their performances:

Table 2: CIFAR-10 accuracy for different CCT models

In Table 2 we notice that having a lower number of convolution layers actually helps increase performance for most models. Our best model gets above 95% accuracy on CIFAR-10 and has just over 3 million parameters. It should be noted that an extra 300 epochs only increases accuracy by ~0.5%. The only ResNet that beats us is ResNet-1001, which has 10M parameters and a maximum accuracy of 95.38%.

We’ve also put together a Google Colab Notebook which trains our smallest model for only 10 minutes on CIFAR-10 and reaches 80% top-1 accuracy. You can actually use this to view randomly drawn images from CIFAR-10 along with their original and predicted class labels.

Adding Augmentations and Tuning (update)

While we’re doing great so far, there are some basic tricks we can use to push the performance even further. We found that weight decay plays an extremely important role in accuracy, so if you’re training models yourself, play around with this hyperparameter. Additionally, augmentations are key! We added MixUp, CutMix, RandAugment (previously we used AutoAugment), and Random Erasing. These can all easily be incorporated with the Timm Training Script.

Table 3: Top-1 validation accuracy comparisons. † variants used a batch size of 64 instead of 128. ◇ variants used lower weight decay (1e-4) and were run for 500 epochs. ★ variants were trained longer (5k epochs) with extra augmentations.

From this we can see that these modifications make a big difference in accuracy while not affecting the number of parameters or compute (MACs). We did find, though, that with these augmentations we benefited from longer training times, upwards of 5000 epochs. In Table 4 below we show the difference that the number of training epochs makes given these changes. As can be seen, this is going to be dependent upon the specific problem you are solving. We did not find any changes when training longer than 5k epochs.

Table 4: Comparison of CCT-7/3x1 with augmentations but different training times.

Extending to ImageNet

Our model also extends to larger datasets, such as ImageNet. The current model scales up quickly, with the number of parameters increasing substantially, because the embedding dimension increases as our model size does. Like many reading this tutorial, we too are limited by computational resources and have not done as extensive a hyper-parameter search as we did on CIFAR-10 and CIFAR-100. Regardless, we show that our model is still competitive. Compared to ViT we get higher accuracy with only a quarter of the parameters and lower complexity (MACs).

Table 5: ImageNet Top-1 validation accuracy comparisons. ↑ indicates pretraining on 224×224 and fine-tuning to 384×384. ★ indicates training with extra augmentation. ♡ means we trained for 300 epochs + 10 cooldown epochs.

We note that ViT trains on images of size 224×224 and fine-tunes on images of size 384×384 (indicated with ↑384). Training on smaller images is often easier since you can use larger batch sizes, and frequently this can speed up training and reduce gradient explosions. The paper reports that vision transformers benefit from larger image sizes. Were ViT to only report results on images of size 224×224, their model would have 86.54M parameters and 16.84 GMACs. This should be clear if you understand the attention mechanism described above: larger images hold more information and thus create longer token sequences. Similarly, we show that this change in image size (↑384) can improve accuracy by up to 2%! There is a cost, though, as it also slightly increases the number of parameters and slows down the network (~3x the GMACs).

This tradeoff between accuracy and parameters/compute is one that often has to be taken into consideration when picking the right model for your own use. For example, MobileNet is often used on mobile devices but trails the SOTA ImageNet accuracy (no extra data) by ~12%; the reason is that MobileNet is far smaller and faster than SOTA ImageNet models. When picking models, these factors are extremely important to consider. Be sure to understand the hardware that your models will be deployed on when making these decisions.

But Wait, There’s More!: Natural Language Processing

We can also apply this type of network to Natural Language Processing, specifically in text classification. You may want to use this kind of thing for Spam Detection or Sentiment Analysis. Basically we’re limiting ourselves to NLP tasks that are analogous to our vision tasks, classification.

But we gotta make a big change here. Words are… well words and not numbers. In an image we have pixels which are defined by numbers in the range of 0 to 255 and using 3 channels for Red, Green, and Blue. So we can define the color red with the vector (255, 0, 0). But how do we define a word this way? How do we get the computer to understand our inputs?

The best thing to do is use a pretrained word embedding like GloVe, but you can also train your own. GloVe has 70,000 words and each word is embedded into a 300-element word vector. So this gives us an extra 21 million parameters! That’s almost 100x our smallest model! But if it’s pre-trained you won’t have to worry much, except for the storage size. Maybe you have a limited vocabulary and don’t want to waste all that space? Look no further, we’ve got your back! We actually don’t have to do very much, because PyTorch is kind enough to provide us with an embedding function. All we really need to tell it is the vocabulary size (how many words we know) and the embedding dimension. Of course, a larger embedding dimension will be a bit more expressive but will take more space, so play around with this to find what is right for you. 300 is a typical number, so start there.

Code 9: Word Embedder
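A sketch of how this looks with PyTorch’s nn.Embedding, either trained from scratch or initialized from pretrained vectors (the vocabulary size and the `glove_vectors` tensor below are placeholders you would fill in yourself):

```python
import torch
import torch.nn as nn

# Train your own embedding: just give PyTorch the vocabulary size and embedding dimension.
vocab_size, embedding_dim = 20_000, 300
embedder = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

# Or start from pretrained vectors (e.g. GloVe) and freeze them so they don't train.
glove_vectors = torch.randn(vocab_size, embedding_dim)  # placeholder: load real GloVe weights here
pretrained_embedder = nn.Embedding.from_pretrained(glove_vectors, freeze=True, padding_idx=0)

tokens = torch.tensor([[5, 42, 7, 0]])  # a batch with one sentence of word indices (0 = padding)
word_vectors = embedder(tokens)         # shape: (1, 4, 300)
```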

So we can use this function to create our word embeddings. In our network we will treat embedded words like our image patches. From this function you can also see that we can pass in pretrained values, for example those from GloVe. We can freeze the embedding, so it doesn’t learn, by turning off its weights’ requires_grad attribute.

Because we changed our embedding space we need to change some things. We no longer have a patch embedding (hard to take patches of words!) so we’ll call this a Tokenizer, to keep with NLP terminology. The formatting is a little different but for the most part this is the same. The difference is that we need 1D convolutions here. Additionally, we need to enable masking in case we only want to use words that we have previously seen.

Code 10: Updated Tokenizer for NLP
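A sketch of what such a tokenizer might look like with a 1D convolution and a simple padding mask (the names, defaults, and the way the mask is applied are assumptions; propagating the mask through pooling is omitted for brevity):

```python
import torch
import torch.nn as nn


class TextTokenizer(nn.Module):
    """Embed word indices, mask out padding, and run a 1D convolution over the sequence."""

    def __init__(self, vocab_size, embedding_dim=300, out_dim=256,
                 kernel_size=3, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.conv = nn.Conv1d(embedding_dim, out_dim, kernel_size=kernel_size,
                              stride=1, padding=kernel_size // 2)
        self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, tokens, mask=None):
        # tokens: (batch, sequence_length) word indices; mask: (batch, sequence_length) of 0/1
        x = self.embedding(tokens)        # (batch, seq, embedding_dim)
        if mask is not None:
            x = x * mask.unsqueeze(-1)    # zero out the padding positions
        x = self.conv(x.transpose(1, 2))  # Conv1d wants (batch, channels, seq)
        x = self.pool(torch.relu(x))
        return x.transpose(1, 2)          # back to (batch, new_seq, out_dim)
```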

After this we’re all set to go! We’ll test it, using GloVe, on two different datasets, AG_News and TREC (both of which you can get using the PyTorch torchtext classification datasets). We won’t be counting the parameters from GloVe for our model size since this is a constant. As you can see, our model also performs well as a text classifier.

Table 6: NLP Text Classification Results on AG News and TREC

Conclusion

Hopefully we’ve demystified the Vision Transformer for you and made you feel that you too can build this powerful model. Because Transformers lack a local inductive bias, it has been believed that these architectures are extremely data-hungry, and thus also computationally hungry. As we’ve shown, you don’t need massive resources or data to make effective models; we’ve dispelled the myth of the “data-hungry transformer,” even allowing training on a CPU! We hope this work will help those with limited data and resources contribute to transformer research, and we encourage everyone to play around with these networks. In this way we can help democratize AI research, especially with respect to Transformers.

We have open-sourced the code from our paper to allow anyone to view or modify the source code, to help verify our results, and to encourage experimentation. If you would like to follow more of our work you can check out the SHI Lab at the University of Oregon, check out our GitHub for more open-sourced work, or follow me on Twitter for updates and new work.

If you found this article useful, please cite us:

@article{walton2021Escaping,
title = "Training Compact Transformers from Scratch in 30 Minutes with PyTorch",
author = "Steven Walton and Ali Hassani and Abulikemu Abuduweili and Humphrey Shi",
journal = "medium.com/pytorch",
year = "2021",
url = "https://medium.com/pytorch/training-compact-transformers-from-scratch-in-30-minutes-with-pytorch-ff5c21668ed5"
}
