Reconstruct the Complete Image from Just a Few Patches | Building Masked Autoencoders as Scalable Learners

Shubh Mishra

Published in The Deep Hub

7 min read · Mar 7, 2024

Hey 👋

Hope you're doing great!

So far, we have covered various important ViT architectures in great detail. In this part of the Vision Transformer series, I will build the Masked Autoencoder (MAE) Vision Transformer from scratch using PyTorch. Without further ado, let's get straight to it!

Masked Autoencoders

MAE is a self-supervised learning approach: it doesn't need any pre-labeled targets, because the training target comes from the input image itself. The approach centers on masking 75% of the patches of an image. After splitting the image into a grid of (H / patch_size) × (W / patch_size) patches, where H and W are the height and width of the image, we randomly mask 75% of the patches and feed only the remaining visible patches to a standard ViT encoder.
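To make the numbers concrete, here is the patch arithmetic for the default configuration used in the code of this article (224x224 image, 16x16 patches, 75% masking). It is plain arithmetic, shown only to fix the sizes of T (all patches) and K (visible patches) that appear later.

```python
# Patch and token counts for img_size=224, patch_size=16, mask_ratio=0.75.
img_size, patch_size, mask_ratio = 224, 16, 0.75

grid = img_size // patch_size                     # patches per side
num_patches = grid * grid                         # T, the sequence length
len_keep = int(num_patches * (1 - mask_ratio))    # K, visible patches fed to the encoder
num_masked = num_patches - len_keep               # patches the decoder must reconstruct

print(grid, num_patches, len_keep, num_masked)    # 14 196 49 147
```

So the encoder only ever processes 49 of the 196 patch tokens (plus the cls token), which is where most of MAE's training-time savings come from.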

The main goal here is to reconstruct the missing patches with only the known patches in the image.

Figure: Input (75% of the patches masked out) | Target (reconstruct the missing pixels)

The MAE consists of three main components:

  1. Random Masking
  2. Encoder
  3. Decoder

Random Masking

Conceptually, random masking is as simple as selecting random patches of an image and masking around 3/4 of them. However, the official implementation uses a different, more efficient technique:

```python
import torch


def random_masking(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort of random noise.
    x: [B, T, D], sequence of patch tokens
    """
    B, T, D = x.shape
    len_keep = int(T * (1 - mask_ratio))

    # create noise of shape (B, T) to later generate random indices
    noise = torch.rand(B, T, device=x.device)

    # argsort of the noise gives a random permutation (ids_shuffle);
    # argsort of that permutation gives its inverse (ids_restore)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first len_keep tokens of the shuffled order
    ids_keep = ids_shuffle[:, :len_keep]
    x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([B, T], device=x.device)
    mask[:, :len_keep] = 0

    # unshuffle so the mask lines up with the original token order
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x, mask, ids_restore
```
  1. Let's assume the input shape is (B, T, D). We first create a random tensor of shape (B, T) and pass it to argsort, which returns the indices that would sort it. Because the noise is random, these indices (ids_shuffle) form a random permutation of the token positions. For example, torch.argsort(torch.tensor([0.3, 0.4, 0.2])) returns tensor([2, 0, 1]).
  2. We then pass ids_shuffle to a second argsort to get ids_restore. This is the inverse permutation: ids_restore[i] tells us where token i ended up in the shuffled order, so gathering with it undoes the shuffle.
  3. Next, we gather the tokens we want to keep, i.e. the first len_keep indices of ids_shuffle.
  4. We generate the binary mask, marking the tokens to keep as 0 and the tokens to remove as 1. At this point the mask is still in shuffled order.
  5. Finally, we unshuffle the mask with ids_restore, so that each entry lines up with its token in the original input. The mask then tells us which positions of the original sequence are kept (0) and which are masked (1).
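A tiny worked example may help here. The sketch below uses hand-picked noise values for a single sample with T = 4 tokens and keeps 2 of them; it shows that ids_restore is the inverse of ids_shuffle and that the unshuffled mask marks exactly the kept positions with 0.

```python
import torch

# Hand-picked noise for one sample with T = 4 tokens.
noise = torch.tensor([[0.3, 0.9, 0.1, 0.5]])

ids_shuffle = torch.argsort(noise, dim=1)        # random permutation of positions
ids_restore = torch.argsort(ids_shuffle, dim=1)  # its inverse permutation
print(ids_shuffle)  # tensor([[2, 0, 3, 1]])
print(ids_restore)  # tensor([[1, 3, 0, 2]])

len_keep = 2
ids_keep = ids_shuffle[:, :len_keep]             # tokens 2 and 0 are kept

mask = torch.ones(1, 4)
mask[:, :len_keep] = 0                           # [0, 0, 1, 1] in shuffled order
mask = torch.gather(mask, dim=1, index=ids_restore)
print(mask)  # tensor([[0., 1., 0., 1.]])  -> positions 0 and 2 kept
```

Gathering ids_shuffle with ids_restore gives back [0, 1, 2, 3], which is the "inverse permutation" property the decoder relies on later.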

NOTE: Instead of picking patches at random locations one by one, the official implementation uses a shuffling technique:

  1. Generate a random permutation of the token indices, as we did with ids_shuffle, and keep only the first 25% of them (int(T * (1 - 3/4)), i.e. int(T / 4)). All remaining tokens are masked.
  2. Reorder (unshuffle) the mask using ids_restore. Before this step the mask has its first 25% set to 0, but those entries refer to shuffled positions, so we reorder the mask to place each 0 at the index of the token it actually belongs to.
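As a sanity check of the shuffling trick, the following sketch runs a condensed copy of random_masking on a toy sequence whose features encode each token's original position, then asserts that the positions marked 0 in the mask are exactly the tokens that were gathered. The toy sizes (B=2, T=16, D=8) are arbitrary choices for illustration.

```python
import torch


def random_masking(x, mask_ratio):
    # Condensed copy of random_masking from above (same logic).
    B, T, D = x.shape
    len_keep = int(T * (1 - mask_ratio))
    ids_shuffle = torch.argsort(torch.rand(B, T, device=x.device), dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones(B, T, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return x_kept, mask, ids_restore


torch.manual_seed(0)
B, T, D = 2, 16, 8
# Every feature of token i equals i, so we can tell which tokens survived.
x = torch.arange(T, dtype=torch.float32).view(1, T, 1).expand(B, T, D).contiguous()

x_kept, mask, ids_restore = random_masking(x, mask_ratio=0.75)
print(x_kept.shape)      # torch.Size([2, 4, 8])
print(mask.sum(dim=1))   # tensor([12., 12.])

# The positions marked 0 in the mask are exactly the tokens that were kept.
for b in range(B):
    kept_positions = sorted(x_kept[b, :, 0].long().tolist())
    unmasked_positions = (mask[b] == 0).nonzero().flatten().tolist()
    assert kept_positions == unmasked_positions
```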

Encoder

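```python
import torch
import torch.nn as nn

# PatchEmbedding and Block are the standard ViT modules from the earlier
# article in this series; random_masking is the function defined above.


class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16,
                 encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()
        self.patch_embed = PatchEmbedding(emb_size=emb_size)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])

    def encoder(self, x, mask_ratio):
        x = self.patch_embed(x)                # (B, T + 1, C), cls token first

        # split off the cls token so it is never masked
        cls_token = x[:, :1, :]
        x = x[:, 1:, :]

        x, mask, restore_id = random_masking(x, mask_ratio)   # x: (B, K, C)

        # put the cls token back in front of the visible tokens
        x = torch.cat((cls_token, x), dim=1)

        x = self.encoder_transformer(x)

        return x, mask, restore_id
```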

1. PatchEmbedding and Block are the standard implementations from the ViT model. I've covered them before; you can check them out here: https://medium.com/thedeephub/building-vision-transformer-from-scratch-using-pytorch-an-image-worth-16x16-words-24db5f159e27

2. We first get the patch embeddings of our image, (B, C, H, W) → (B, T, C). The PatchEmbedding implementation used here also returns the cls_token concatenated to the front of the embedding tensor x. If you prefer, you can use the standard PatchEmbed and Block from the timm library (from timm.models.vision_transformer import PatchEmbed, Block); note that timm's PatchEmbed only projects the patches, so the cls_token and positional embeddings then have to be added separately.

3. Since x already contains the cls_token, we first remove it and then pass the remaining tokens to random_masking. This returns x: (B, K, C), mask: (B, T) and restore_id: (B, T), where K is the number of tokens kept, i.e. T/4.

4. We then concatenate the cls_token back and pass the sequence to a standard encoder_transformer.
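To see the encoder shapes end to end without the earlier article's modules, here is a small runnable sketch. TinyPatchEmbedding is a hypothetical stand-in for PatchEmbedding (conv patchify, cls token, learnable positions), nn.TransformerEncoderLayer stands in for Block, and the sizes (32x32 images, 8x8 patches, width 64) are shrunk so it runs quickly. It is an illustration of the data flow, not the author's exact code.

```python
import torch
import torch.nn as nn


class TinyPatchEmbedding(nn.Module):
    """Hypothetical stand-in for PatchEmbedding: conv patchify,
    prepend a learnable cls token, add learnable positional embeddings."""

    def __init__(self, in_channels=3, patch_size=8, emb_size=64, img_size=32):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_size))
        num_patches = (img_size // patch_size) ** 2
        self.pos = nn.Parameter(torch.zeros(1, num_patches + 1, emb_size))

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)        # (B, T, C)
        cls = self.cls_token.expand(x.shape[0], -1, -1)     # (B, 1, C)
        return torch.cat([cls, x], dim=1) + self.pos        # (B, T + 1, C)


def random_masking(x, mask_ratio):
    # Condensed copy of random_masking from earlier in the article.
    B, T, D = x.shape
    len_keep = int(T * (1 - mask_ratio))
    ids_shuffle = torch.argsort(torch.rand(B, T, device=x.device), dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones(B, T, device=x.device)
    mask[:, :len_keep] = 0
    return x, torch.gather(mask, dim=1, index=ids_restore), ids_restore


patch_embed = TinyPatchEmbedding()
# nn.TransformerEncoderLayer stands in for the ViT Block here.
encoder = nn.Sequential(*[nn.TransformerEncoderLayer(64, 4, batch_first=True) for _ in range(2)])

imgs = torch.randn(2, 3, 32, 32)
x = patch_embed(imgs)                              # (2, 17, 64): cls + 16 patches
cls_token, x = x[:, :1, :], x[:, 1:, :]
x, mask, ids_restore = random_masking(x, 0.75)     # x: (2, 4, 64)
x = encoder(torch.cat([cls_token, x], dim=1))      # (2, 5, 64)

print(x.shape, mask.shape, ids_restore.shape)
```

Only K + 1 = 5 tokens per image go through the encoder, while mask and ids_restore keep the full length T = 16 so the decoder can rebuild the sequence.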

Decoder

The decoding stage first projects the encoder output from emb_size to decoder_emb_size. Recall that the encoder output has shape (B, K + 1, C): the cls_token plus K = T/4 visible tokens. We then append a learnable mask_token for every masked position, restore the original token order, and feed the full sequence into another Vision Transformer (the decoder), as shown in Figure 1.

```python
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

# PatchEmbedding and Block come from the earlier ViT article; random_masking is defined above.


class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16,
                 encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()
        self.patch_embed = PatchEmbedding(emb_size=emb_size)
        self.decoder_embed = nn.Linear(emb_size, decoder_emb_size)
        # Learnable decoder positional embedding (the official MAE code uses fixed
        # 2D sin-cos values). A frozen all-zero tensor would leave the identical
        # mask tokens with no position information at all.
        num_patches = (img_size // patch_size) ** 2
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, decoder_emb_size) * 0.02)
        self.decoder_pred = nn.Linear(decoder_emb_size, patch_size**2 * in_channels, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])
        self.decoder_transformer = nn.Sequential(*[Block(decoder_emb_size, num_head) for _ in range(decoder_num_layers)])
        # Fixed (non-learned) patchify used to build the pixel reconstruction target:
        # (B, C, H, W) -> (B, T, patch_size**2 * in_channels)
        self.project = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)

    def encoder(self, x, mask_ratio):
        x = self.patch_embed(x)

        cls_token = x[:, :1, :]
        x = x[:, 1:, :]

        x, mask, restore_id = random_masking(x, mask_ratio)

        x = torch.cat((cls_token, x), dim=1)

        x = self.encoder_transformer(x)

        return x, mask, restore_id

    def decoder(self, x, restore_id):
        x = self.decoder_embed(x)                                  # (B, K + 1, decoder_emb_size)

        # one mask token per removed patch: T + 1 - (K + 1) = T - K
        mask_tokens = self.mask_token.repeat(x.shape[0], restore_id.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)         # drop cls, append mask tokens
        x_ = torch.gather(x_, dim=1, index=restore_id.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)                    # put the cls token back

        # add pos embed
        x = x + self.decoder_pos_embed

        x = self.decoder_transformer(x)

        # predictor projection: decoder_emb_size -> patch_size**2 * in_channels pixel values
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x
```

1. We pass the encoder output through decoder_embed to change its width to decoder_emb_size. Then we create one mask_token for every masked position and concatenate them with x, excluding its cls_token.

2. The tensor now holds the K visible tokens first, followed by the mask tokens, which is still the shuffled order. We want every token back at its original index, and we can do that with ids_restore (restore_id in the code).

3. ids_restore holds the indices which, when passed to torch.gather, unshuffle the input. The visible tokens selected in random_masking (the first len_keep indices of ids_shuffle) move back to their original positions, and every masked position receives a mask token. We then concatenate the cls_token again in front of the reordered patches.

4. Finally, we add the decoder positional embeddings, feed the whole sequence to the decoder transformer blocks, project each token to patch_size² · in_channels pixel values with decoder_pred, remove the cls_token, and return x to calculate the loss.
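The reordering step in points 2 and 3 can be checked in isolation. In this sketch (toy sizes, cls token omitted, and a hypothetical mask token filled with -1 so masked slots are easy to spot), each token's features hold its original index. After appending the mask tokens and gathering with ids_restore, every visible token should sit back at its own index and every other slot should hold the mask token.

```python
import torch

torch.manual_seed(0)
B, T, D = 1, 8, 4
len_keep = 2  # K visible tokens (mask_ratio = 0.75)

# Original token features: every feature of token i equals i.
tokens = torch.arange(T, dtype=torch.float32).view(1, T, 1).repeat(B, 1, D)

ids_shuffle = torch.argsort(torch.rand(B, T), dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))  # (B, K, D)

# Hypothetical mask token: -1 everywhere, so masked slots are easy to spot.
mask_token = torch.full((1, 1, D), -1.0)
mask_tokens = mask_token.repeat(B, T - len_keep, 1)                    # (B, T - K, D)

x_ = torch.cat([visible, mask_tokens], dim=1)                          # shuffled order
x_ = torch.gather(x_, 1, ids_restore.unsqueeze(-1).repeat(1, 1, D))     # original order

print(ids_keep[0].tolist())   # which positions were visible
print(x_[0, :, 0].tolist())   # visible positions hold their index, the rest hold -1.0
```

This is the same concatenate-then-gather pattern as in decoder(); the real mask token is learned, and the decoder's positional embeddings are what let the otherwise identical mask tokens be told apart.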

Loss function

The encoder of the Masked Autoencoder sees only the unmasked patches, and the model learns to reconstruct the masked patches of the image. The loss function is the Mean Squared Error between the decoder's prediction and the target patches, averaged over the masked patches only.
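A small hand-computed example shows how the mask enters the loss: per-patch errors are averaged over the masked patches only. The values below are made up for illustration, and the three loss lines mirror the loss method in the code that follows.

```python
import torch

# Toy batch: 1 image, 4 patches, 2 values per patch.
pred = torch.tensor([[[1.0, 1.0], [0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]])
target = torch.tensor([[[1.0, 1.0], [1.0, 1.0], [0.0, 0.0], [3.0, 3.0]]])
mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])  # 1 = masked (removed) patch

loss = (pred - target) ** 2
loss = loss.mean(dim=-1)                  # per-patch MSE: [0., 1., 4., 0.]
loss = (loss * mask).sum() / mask.sum()   # average over masked patches only
print(loss)  # tensor(2.5000)
```

Averaging over all four patches would give (0 + 1 + 4 + 0) / 4 = 1.25 instead; the MAE paper computes the loss on the masked patches only, which is what multiplying by mask and dividing by mask.sum() achieves.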

```python
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

# PatchEmbedding and Block are the standard ViT modules from the earlier article
# in this series (timm's PatchEmbed/Block are an alternative).


class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16,
                 encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()
        self.patch_embed = PatchEmbedding(emb_size=emb_size)
        self.decoder_embed = nn.Linear(emb_size, decoder_emb_size)
        # Learnable decoder positional embedding (the official MAE code uses fixed
        # 2D sin-cos values). A frozen all-zero tensor would leave the identical
        # mask tokens with no position information at all.
        num_patches = (img_size // patch_size) ** 2
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, decoder_emb_size) * 0.02)
        self.decoder_pred = nn.Linear(decoder_emb_size, patch_size**2 * in_channels, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])
        self.decoder_transformer = nn.Sequential(*[Block(decoder_emb_size, num_head) for _ in range(decoder_num_layers)])
        # Fixed (non-learned) patchify for the pixel target:
        # (B, C, H, W) -> (B, T, patch_size**2 * in_channels).
        # A trainable Conv2d target could collapse (e.g. to all zeros) under MSE.
        self.project = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)

    @staticmethod
    def random_masking(x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort of random noise.
        x: [B, T, D], sequence of patch tokens
        """
        B, T, D = x.shape
        len_keep = int(T * (1 - mask_ratio))

        # create noise of shape (B, T) to later generate random indices
        noise = torch.rand(B, T, device=x.device)

        # random permutation (ids_shuffle) and its inverse (ids_restore)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first len_keep tokens of the shuffled order
        ids_keep = ids_shuffle[:, :len_keep]
        x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, T], device=x.device)
        mask[:, :len_keep] = 0

        # unshuffle so the mask lines up with the original token order
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x, mask, ids_restore

    def encoder(self, x, mask_ratio):
        x = self.patch_embed(x)                     # (B, T + 1, C), cls token first

        cls_token = x[:, :1, :]
        x = x[:, 1:, :]

        x, mask, restore_id = self.random_masking(x, mask_ratio)   # x: (B, K, C)

        x = torch.cat((cls_token, x), dim=1)

        x = self.encoder_transformer(x)

        return x, mask, restore_id

    def decoder(self, x, restore_id):
        x = self.decoder_embed(x)

        # one mask token per removed patch, then unshuffle back to the original order
        mask_tokens = self.mask_token.repeat(x.shape[0], restore_id.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=restore_id.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # add pos embed
        x = x + self.decoder_pos_embed

        x = self.decoder_transformer(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def loss(self, imgs, pred, mask):
        """
        imgs: [B, 3, H, W]
        pred: [B, T, patch*patch*3]
        mask: [B, T], 0 is keep, 1 is remove
        """
        target = self.project(imgs)                 # [B, T, patch*patch*3] pixel values

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)                    # [B, T], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()     # mean loss on removed patches
        return loss

    def forward(self, img):
        mask_ratio = 0.75

        x, mask, restore_ids = self.encoder(img, mask_ratio)
        pred = self.decoder(x, restore_ids)
        loss = self.loss(img, pred, mask)
        return loss, pred, mask
```

To summarize, forward() does the following:

1. Encodes only the unmasked (visible) patches with the encoder Vision Transformer.

2. Combines the encoder output for the visible patches with mask tokens for the masked ones and restores their original order.

3. Passes the full sequence of visible and masked tokens, in their original positions, through the decoder Vision Transformer.

4. Computes the Mean Squared Error over the last dimension between the decoder's prediction (B, T, patch_size² · in_channels) and the target patches of the image (B, T, patch_size² · in_channels), then averages it over the masked patches only.
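In the MAE paper the reconstruction target is the raw pixel values of each patch (optionally normalized per patch). A plain-PyTorch patchify/unpatchify pair is sketched below; the (h, w, p, p, C) flattening order is a choice, and any order works as long as the target and the prediction use the same one. unpatchify is handy for turning the decoder's (B, T, patch_size² · in_channels) prediction back into an image for visualization. Both helper names are hypothetical, not part of the code above.

```python
import torch


def patchify(imgs, p):
    """(N, C, H, W) -> (N, T, p*p*C), patches in row-major order."""
    N, C, H, W = imgs.shape
    h, w = H // p, W // p
    x = imgs.reshape(N, C, h, p, w, p)
    x = x.permute(0, 2, 4, 3, 5, 1)           # (N, h, w, p, p, C)
    return x.reshape(N, h * w, p * p * C)


def unpatchify(x, p, C=3):
    """(N, T, p*p*C) -> (N, C, H, W); assumes a square patch grid."""
    N, T, _ = x.shape
    h = w = int(T ** 0.5)
    x = x.reshape(N, h, w, p, p, C)
    x = x.permute(0, 5, 1, 3, 2, 4)           # (N, C, h, p, w, p)
    return x.reshape(N, C, h * p, w * p)


imgs = torch.randn(2, 3, 224, 224)
patches = patchify(imgs, 16)
print(patches.shape)                               # torch.Size([2, 196, 768])
print(torch.equal(unpatchify(patches, 16), imgs))  # True
```

Because both functions only reshape and permute, the round trip is exact, so the loss compares the prediction against the untouched pixel values.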

Thanks for reading. I hope you enjoyed it. Please consider giving this one a clap and follow me for more. This motivates me to keep writing great stuff. You can check the entire code on my GitHub: https://github.com/mishra-18/ML-Models/blob/main/Vission%20Transformers/mae.py
