Using CNNs to Calculate Attention | Building CvT from scratch using PyTorch | Paper explanation

Shubh Mishra
Published in The Deep Hub
7 min read · Feb 29, 2024

Hey 👋

Hope you’re doing great.

In this part of the Vision Transformer series, I will explain the key idea behind introducing convolutions to ViT models and build the model from scratch. So let’s get straight to it!

Check out my previous work on Swin Transformer: https://medium.com/@mishra4475/building-swin-transformer-from-scratch-using-pytorch-hierarchical-vision-transformer-using-shifted-91cbf6abc678

Convolutional Vision Transformer

Honestly, the CvT architecture proposed in the paper is among the very first architectures that I found refreshingly straightforward to understand in a single read, without needing supplementary tutorials or blogs. I highly recommend that you at least give it a read: https://arxiv.org/pdf/2103.15808.pdf

Vision Transformer models are gaining recognition for computer vision tasks. However, the standard Vision Transformer is still outperformed by standard CNN architectures on smaller datasets. One reason is that CNNs are very good at local feature extraction, which suits images because spatially neighbouring pixels are highly correlated. Another is the hierarchical structure of convolutional kernels, which learn visual patterns in their local spatial context at varying levels of complexity, from simple low-level edges and textures to higher-order semantic patterns. This hierarchical structure is considered a key reason for the strength of CNNs. The Swin Transformer architecture, which I covered in detail in the previous blog, addresses it as well.

Image taken from the original paper | Introducing Convolution to Vision Transformers

To introduce convolutions into the standard ViT, the paper proposes two main things:

  1. Convolutional Token Embedding
  2. Convolutional Transformer Block

The only thing new here is the Convolutional Transformer Block as I’ve already explained the Embedding module in the previous blogs.

Convolutional Token Embedding

This is much the same as the embedding we’ve been using until now: we take an image, pass it through a convolution layer to get a new feature map, and reshape that feature map into a sequence of token embeddings, as shown in stage 1 of the figure above.

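import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class CVTEmbedding(nn.Module):
    def __init__(self, in_ch, embed_dim, patch_size, stride):
        super().__init__()
        # patch_size is the kernel size; with stride < patch_size the patches overlap.
        self.embed = nn.Sequential(
            nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=stride)
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.embed(x)                         # x: (B, C, H, W)
        x = rearrange(x, 'b c h w -> b (h w) c')  # x: (B, T, C) with T = h * w
        x = self.norm(x)
        return x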

Notice that each stage uses its own kernel size and stride, and that the stride is smaller than the kernel (the patch size), so neighbouring patches overlap. As a result, the number of tokens per side after the conv embedding is no longer H // patch_size as in the standard Vision Transformer. We do this because we want the local feature extraction of a convolutional neural network.
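To make the token count concrete, here is a small sketch of the standard convolution output-size formula, floor((size + 2 * padding - kernel) / stride) + 1. The helper name conv_out_size is made up for this illustration, and padding=0 is assumed because CVTEmbedding above does not pad. The ViT comparison assumes the usual 16x16 non-overlapping patches.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial size after a convolution (hypothetical helper for illustration)."""
    return (size + 2 * padding - kernel) // stride + 1


# Stage-1 settings used later in this post: kernel 7, stride 4, no padding.
side = conv_out_size(224, kernel=7, stride=4)
print(side, side * side)            # 55 per side, 3025 tokens

# A standard ViT with non-overlapping 16x16 patches (kernel == stride == 16).
vit_side = conv_out_size(224, kernel=16, stride=16)
print(vit_side, vit_side * vit_side)  # 14 per side, 196 tokens
```

With overlapping patches the first stage keeps far more (and much smaller-stride) tokens than a plain ViT would, which is what lets later stages downsample gradually like a CNN.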

Convolutional Transformer Block

Figure showing different types of techniques to get the Query, Key, and Value for self-attention. Image taken from the original CvT paper.

And that is the only thing we are going to do here: instead of computing the query, key, and value with a linear layer, as in the standard Vision Transformer, we reshape the tokens into a feature map and pass it through convolution layers to get them.
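One way to see what changes is to note that a per-token linear projection is the special case of a convolutional projection with a 1x1 kernel. The sketch below checks this numerically and then swaps in a 3x3 kernel. The tensor sizes and the names linear, conv1x1 and conv3x3 are arbitrary choices for this illustration, not part of the model code in this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 8, 5, 5
tokens = torch.randn(B, H * W, C)          # (B, T, C), as a ViT block sees them

linear = nn.Linear(C, C, bias=False)
conv1x1 = nn.Conv2d(C, C, kernel_size=1, bias=False)
# Copy the linear weights into the 1x1 kernel: (C_out, C_in) -> (C_out, C_in, 1, 1).
with torch.no_grad():
    conv1x1.weight.copy_(linear.weight.view(C, C, 1, 1))

# Linear projection applied to every token independently.
out_linear = linear(tokens)

# Same projection through the conv: reshape tokens to a feature map and back.
fmap = tokens.transpose(1, 2).reshape(B, C, H, W)
out_conv = conv1x1(fmap).flatten(2).transpose(1, 2)

same = torch.allclose(out_linear, out_conv, atol=1e-5)
print(same)

# A 3x3 kernel with padding=1 keeps the shape, but each output token now
# also depends on its spatial neighbours.
conv3x3 = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False)
out_3x3 = conv3x3(fmap).flatten(2).transpose(1, 2)
print(out_3x3.shape)
```

So the convolutional projection keeps the token layout of a ViT while letting each query, key and value look at a small neighbourhood instead of a single token.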

class MultiHeadAttention(nn.Module):
    def __init__(self, in_dim, num_heads, kernel_size=3, with_cls_token=False):
        super().__init__()
        padding = (kernel_size - 1) // 2  # "same" padding keeps h and w unchanged
        self.num_heads = num_heads
        self.head_dim = in_dim // num_heads
        self.with_cls_token = with_cls_token
        # One projection each for query, key and value. Sharing a single conv
        # would make q, k and v identical.
        self.conv_q = self._conv_proj(in_dim, kernel_size, padding)
        self.conv_k = self._conv_proj(in_dim, kernel_size, padding)
        self.conv_v = self._conv_proj(in_dim, kernel_size, padding)
        self.att_drop = nn.Dropout(0.1)

    @staticmethod
    def _conv_proj(dim, kernel_size, padding):
        return nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=padding, stride=1),
            nn.BatchNorm2d(dim),
            Rearrange('b c h w -> b (h w) c'),
        )

    def forward_conv(self, x):
        B, hw, C = x.shape
        H = W = int(hw ** 0.5)  # assumes a square token grid
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)
        return q, k, v

    def forward(self, x):
        # x: (B, H*W, C)
        q, k, v = self.forward_conv(x)

        # Split channels into heads: (B, T, C) -> (B, heads, T, head_dim)
        q = rearrange(q, 'b t (H d) -> b H t d', H=self.num_heads)
        k = rearrange(k, 'b t (H d) -> b H t d', H=self.num_heads)
        v = rearrange(v, 'b t (H d) -> b H t d', H=self.num_heads)

        # Scale by the square root of the per-head dimension
        att_score = q @ k.transpose(2, 3) / self.head_dim ** 0.5
        att_score = F.softmax(att_score, dim=-1)
        att_score = self.att_drop(att_score)

        x = att_score @ v
        x = rearrange(x, 'b H t d -> b t (H d)')
        return x

  1. Here we first rearrange the input x of shape (B, T, C) to (B, C, H, W) so it can be passed through the conv layers.
  2. The conv output has the same spatial size as its input, so we rearrange it from (B, C, H, W) back to (B, T, C), with T = H × W.

Note: Here the input x is the embedding of the image. When we rearrange it (x: (B, T, C) → (B, C, H, W)), we do not get the original image back. H and W are now the sides of the token grid, so T = H × W is the number of patches produced by the embedding conv, not the number of pixels.

Also, we need to concatenate the class token in front of the tokens in the input feature, as is done in the standard ViT.

X = concat(cls_token, X), so X goes from shape (B, T, C) to (B, T + 1, C)

But here in the attention head we pass the input through a conv layer, and we cannot reshape the tokens T = h × w into an (h, w) grid unless T is a perfect square. With the extra class token it no longer is. That is why we take the cls_token out right before the conv layer and concatenate it back after getting our values.
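A quick way to convince yourself of this is the sketch below. The helper grid_side is hypothetical, and the 13 × 13 grid assumes a 224 × 224 input with the kernel and stride settings used for the three stages later in this post.

```python
import math


def grid_side(num_tokens: int) -> int:
    """Return h (= w) if num_tokens fits a square h x w grid, else raise."""
    side = math.isqrt(num_tokens)
    if side * side != num_tokens:
        raise ValueError(f"{num_tokens} tokens cannot be reshaped to a square grid")
    return side


patch_tokens = 13 * 13                  # 169 patch tokens in the last stage
print(grid_side(patch_tokens))          # 13

try:
    grid_side(patch_tokens + 1)         # 170 tokens once the class token is prepended
except ValueError as err:
    print(err)
```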

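class MultiHeadAttention(nn.Module):
    def __init__(self, in_dim, num_heads, kernel_size=3, with_cls_token=False):
        super().__init__()
        padding = (kernel_size - 1) // 2  # "same" padding keeps h and w unchanged
        self.num_heads = num_heads
        self.head_dim = in_dim // num_heads
        self.with_cls_token = with_cls_token
        # One projection each for query, key and value. Sharing a single conv
        # would make q, k and v identical.
        self.conv_q = self._conv_proj(in_dim, kernel_size, padding)
        self.conv_k = self._conv_proj(in_dim, kernel_size, padding)
        self.conv_v = self._conv_proj(in_dim, kernel_size, padding)
        self.att_drop = nn.Dropout(0.1)

    @staticmethod
    def _conv_proj(dim, kernel_size, padding):
        return nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=padding, stride=1),
            nn.BatchNorm2d(dim),
            Rearrange('b c h w -> b (h w) c'),
        )

    def forward_conv(self, x):
        B, hw, C = x.shape
        if self.with_cls_token:
            # Set the class token aside: T + 1 tokens do not form a square grid.
            cls_token, x = torch.split(x, [1, hw - 1], 1)
        H = W = int(x.shape[1] ** 0.5)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)

        if self.with_cls_token:
            # Put the class token back in front of the projected tokens.
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x):
        # x: (B, H*W, C), or (B, H*W + 1, C) with a class token
        q, k, v = self.forward_conv(x)

        # Split channels into heads: (B, T, C) -> (B, heads, T, head_dim)
        q = rearrange(q, 'b t (H d) -> b H t d', H=self.num_heads)
        k = rearrange(k, 'b t (H d) -> b H t d', H=self.num_heads)
        v = rearrange(v, 'b t (H d) -> b H t d', H=self.num_heads)

        # Scale by the square root of the per-head dimension
        att_score = q @ k.transpose(2, 3) / self.head_dim ** 0.5
        att_score = F.softmax(att_score, dim=-1)
        att_score = self.att_drop(att_score)

        x = att_score @ v
        x = rearrange(x, 'b H t d -> b t (H d)')
        return x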

And that’s it, now we just need to implement the basic architecture.

Vision Transformer Block

class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(4 * dim, dim),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        return self.ff(x)


class Block(nn.Module):
    def __init__(self, embed_dim, num_heads, with_cls_token):
        super().__init__()

        self.mhsa = MultiHeadAttention(embed_dim, num_heads, with_cls_token=with_cls_token)
        self.ff = MLP(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = x + self.dropout(self.mhsa(self.norm1(x)))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

This is just a standard encoder block, as proposed in the paper and used in any other Vision Transformer architecture. If you are completely new to Vision Transformers, you might want to check out one of my previous articles: https://medium.com/@mishra4475/building-vision-transformer-from-scratch-using-pytorch-an-image-worth-16x16-words-24db5f159e27
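Written out, the forward pass of Block above is a pre-norm residual update. In this sketch, LN_1 and LN_2 stand for the two LayerNorm modules, MHSA for the convolutional attention, and MLP for the feed-forward module; the notation is mine, read off the code rather than taken from the paper.

```latex
\begin{aligned}
x' &= x + \mathrm{Dropout}\big(\mathrm{MHSA}(\mathrm{LN}_1(x))\big) \\
y  &= x' + \mathrm{Dropout}\big(\mathrm{MLP}(\mathrm{LN}_2(x'))\big)
\end{aligned}
```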

Here is the vision transformer for a single stage:

class VissionTransformer(nn.Module):
    def __init__(self, depth, embed_dim, num_heads, patch_size, stride, in_ch=3, cls_token=False):
        super().__init__()

        self.patch_size = patch_size
        self.stride = stride
        self.cls_token = cls_token
        self.layers = nn.Sequential(*[Block(embed_dim, num_heads, cls_token) for _ in range(depth)])
        self.embedding = CVTEmbedding(in_ch, embed_dim, patch_size, stride)

        if self.cls_token:
            # Sized by embed_dim rather than a hard-coded 384
            self.cls_token_embed = nn.Parameter(torch.randn(1, 1, embed_dim))

    def forward(self, x, ch_out=False):
        B, C, H, W = x.shape
        x = self.embedding(x)
        if self.cls_token:
            cls_token = repeat(self.cls_token_embed, '() s e -> b s e', b=B)
            x = torch.cat([cls_token, x], dim=1)

        x = self.layers(x)

        if not ch_out:
            # Output size of the un-padded embedding conv. Use ch_out=True for a
            # stage with a class token, since T + 1 tokens cannot form a grid.
            h = (H - self.patch_size) // self.stride + 1
            w = (W - self.patch_size) // self.stride + 1
            x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        return x

Here we rearrange the x to (B C H W) to pass it to the next stage.

Putting it all Together

Finally, we create the CvT module and implement the Convolutional Vision Transformer architecture.

class CvT(nn.Module):
    def __init__(self, embed_dim, num_class):
        super().__init__()

        self.stage1 = VissionTransformer(depth=1,
                                         embed_dim=64,
                                         num_heads=1,
                                         patch_size=7,
                                         stride=4,
                                         )
        self.stage2 = VissionTransformer(depth=2,
                                         embed_dim=192,
                                         num_heads=3,
                                         patch_size=3,
                                         stride=2,
                                         in_ch=64)
        self.stage3 = VissionTransformer(depth=10,
                                         embed_dim=384,
                                         num_heads=6,
                                         patch_size=3,
                                         stride=2,
                                         in_ch=192,
                                         cls_token=True)
        self.ff = nn.Sequential(
            # 6 * embed_dim must equal the stage-3 width (384), i.e. embed_dim = 64
            nn.Linear(6 * embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, num_class)
        )

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x, ch_out=True)  # keep tokens as (B, T + 1, C)
        x = x[:, 0, :]                   # the class token sits at index 0
        x = self.ff(x)
        return x

As mentioned above, the kernel size and stride are set per stage, with the stride smaller than the kernel, for the sake of learning local features just like CNNs.

Each stage takes as in_ch the output channels of the previous stage, i.e. the embedding dimension of the previous vision transformer.
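The channel chaining and the resulting feature-map sizes can be traced without building the model. This is only a bookkeeping sketch: the stage tuples are copied from the CvT class above, conv_out_size is a hypothetical helper, and a 3 × 224 × 224 input with un-padded convolutions is assumed.

```python
def conv_out_size(size, kernel, stride, padding=0):
    """Spatial size after a convolution (hypothetical helper for illustration)."""
    return (size + 2 * padding - kernel) // stride + 1


# (in_ch, embed_dim, patch_size, stride) per stage, copied from the CvT class.
stages = [
    (3, 64, 7, 4),
    (64, 192, 3, 2),
    (192, 384, 3, 2),
]

channels, side = 3, 224      # assumed input: a 3 x 224 x 224 image
shapes = []
for in_ch, embed_dim, patch_size, stride in stages:
    # Each stage must accept exactly the channels the previous one produced.
    assert in_ch == channels, "in_ch must equal the previous embed_dim"
    side = conv_out_size(side, patch_size, stride)
    channels = embed_dim
    shapes.append((channels, side, side))

print(shapes)   # [(64, 55, 55), (192, 27, 27), (384, 13, 13)]
```

The last stage therefore works on 13 × 13 = 169 patch tokens of width 384, plus the class token.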

We then pass only the cls_token to a feed-forward network to get the final output.

if __name__ == '__main__':
    # Usage example
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 3, 224, 224).to(device)
    embed_size = 64
    num_class = 10
    model = CvT(embed_size, num_class).to(device)
    print(model(x).shape)

And that is it.

If you liked my work, consider giving me a follow or a clap. It means a lot to be appreciated for your work, and it motivates me to keep writing these useful blogs for you. You can check out the entire code on my GitHub repository: https://github.com/mishra-18/ML-Models/blob/main/Vission%20Transformers/cvt.py

Thanks for reading.
