Building the DINO model from Scratch with PyTorch: Self-Supervised Vision Transformer

Shubh Mishra
Published in The Deep Hub
9 min read · Jun 9, 2024
Figure: DINO model output on a sprinting dog

Self-Distillation with No labels (DINO)

This is a continuation of my vision transformer series, where I explain the most important architectures and implement them from scratch. In this blog, we will build the self-supervised DINO model from scratch. You can check out the previous blog here: https://medium.com/thedeephub/building-mae-vision-transformer-from-scratch-using-pytorch-masked-autoencoders-are-scalable-2c2e78e0be02

Self-Supervised Learning

Self-supervised learning (SSL) is a type of machine learning where the model learns to understand data without requiring manually labeled examples. Instead, it generates its own supervisory signals from the data. This approach is especially useful when labeled data is scarce or expensive to obtain.

In SSL, the model is trained on tasks built from the input data itself, for example predicting one part of the input from another. Common techniques include:

  • Contrastive Learning: The model learns by distinguishing between similar and dissimilar pairs of data.
  • Predictive Tasks: The model predicts a part of the input data from other parts, such as predicting the next word in a sentence or the context of a word from its surroundings.
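To make the contrastive idea concrete, here is a minimal sketch of an InfoNCE-style loss, similar in spirit to the objectives used by contrastive methods such as SimCLR and MoCo (DINO itself does not use this loss or negative pairs). The function name info_nce_loss and the temperature of 0.1 are illustrative assumptions; each row of z1 is pulled toward the matching row of z2 and pushed away from the other rows in the batch.

```python
import torch
import torch.nn.functional as F


def info_nce_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """InfoNCE-style contrastive loss for a batch of paired embeddings.

    Row i of z1 and row i of z2 are two views of the same sample (a positive pair);
    all other rows of z2 act as negatives for row i of z1.
    """
    z1 = F.normalize(z1, dim=1)  # unit length, so dot products are cosine similarities
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature  # (N, N) similarity matrix
    targets = torch.arange(z1.size(0), device=z1.device)  # positives lie on the diagonal
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
z = torch.randn(8, 128)
aligned = info_nce_loss(z, z)  # matching views: each positive is easy to pick out
unrelated = info_nce_loss(z, torch.randn(8, 128))  # unrelated views: nothing to match
print(aligned.item(), unrelated.item())
```

The loss is low when each sample's two views are more similar to each other than to any other sample in the batch, which is exactly the "similar vs. dissimilar pairs" signal described above.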

DINO Model

The DINO (self-DIstillation with NO labels) model is a self-supervised learning method applied to vision transformers (ViTs). It enables models to learn effective image representations without requiring any labeled data. Developed by researchers at Facebook AI Research (FAIR), DINO uses a student-teacher framework, in which a student network learns to match the outputs of a teacher network built from the student's own past weights, and it achieves strong performance on vision tasks such as linear and k-NN classification on ImageNet.

Student-Teacher Network

In the DINO model, the student-teacher network is a core mechanism that enables self-supervised learning without labeled data. This framework involves two networks: the student network and the teacher network. Both networks are vision transformers, which are designed to process images by treating them as sequences of patches, similar to how transformers handle text sequences.

The student network is tasked with learning to generate meaningful representations from input images. The teacher network, on the other hand, provides target representations that the student network aims to match. The teacher network is not a static entity; it evolves over time by gradually incorporating the parameters of the student network. This is done using a technique called exponential moving average, where the teacher’s parameters are updated to be a weighted average of its current parameters and the student’s parameters.
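A minimal sketch of the exponential moving average (EMA) update described above, using a tiny nn.Linear as a hypothetical stand-in for the ViT. The helper name ema_update and the momentum of 0.9 are illustrative choices; the teacher starts as a copy of the student, is frozen, and then moves a fraction (1 - momentum) of the way toward the student's weights on each update.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, momentum: float) -> None:
    """teacher <- momentum * teacher + (1 - momentum) * student, parameter by parameter."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(momentum).add_(s_param, alpha=1 - momentum)


torch.manual_seed(0)
student = nn.Linear(4, 2)           # stand-in for the student network
teacher = copy.deepcopy(student)    # teacher starts as an exact copy of the student
for p in teacher.parameters():
    p.requires_grad = False         # the teacher never receives gradients

# Pretend an optimizer step shifted every student weight by +1
with torch.no_grad():
    student.weight.add_(1.0)

before = teacher.weight.clone()
ema_update(teacher, student, momentum=0.9)
# Each teacher weight moved 10% of the way toward the student's weight
print((teacher.weight - before).abs().max().item())
```

With a momentum close to 1 (the paper starts at 0.996), the teacher changes very slowly, which is what makes its targets stable.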

The objective is to minimize the discrepancy between the student's and the teacher's outputs for different augmented views of the same image. This is done with a cross-entropy loss between the two networks' output distributions. Unlike contrastive methods, DINO uses no negative pairs; instead, it centers and sharpens the teacher's outputs to keep the representations from collapsing to a trivial solution.

By continuously updating the teacher network based on the student network’s learning progress and training the student network to match the teacher’s output, DINO effectively leverages the strengths of both networks. The teacher network provides stable and consistent targets for the student, while the student network drives the learning process. This collaborative setup allows the model to learn robust and invariant features from the data without the need for manual labels, thereby achieving effective self-supervised learning.

Augmented input for student and teacher

In the DINO model, x1 and x2 (see the figure above) refer to two different augmented views of the same original image x. Both views are passed through both the student and the teacher, and the loss compares the teacher's output on one view with the student's output on the other. The goal is for the student network to produce consistent representations despite these augmentations.

The student and teacher receive differently augmented crops according to the following multi-crop strategy:

  • Global Crops: Two global crops are created from the original image. These are larger crops (224×224) that cover a substantial portion of the image, and they are further transformed with color jittering, Gaussian blur, horizontal flipping, etc. Both the student and the teacher see the global crops.
  • Local Crops: In addition to the global crops, the student network receives several local crops. These are smaller crops (96×96) that focus on different parts of the image, capturing more localized details. The teacher only sees the global crops, which encourages "local-to-global" correspondence.

Here's how we define these augmentations. The images argument is a batch (for example, a list) of PIL images that we want to transform during training.

```python
import torch
from torchvision import transforms

# These augmentations approximate the recipe from the paper. The official implementation
# differs in some details (e.g. it applies color jitter only with probability 0.8 and
# adds solarization to one of the global views).
def global_augment(images):
    """Returns one global view per image as a (B, 3, 224, 224) tensor. `images` is a list of PIL images."""
    global_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),  # Larger crops
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),  # Color jittering
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return torch.stack([global_transform(img) for img in images])


def multiple_local_augments(images, num_crops=6):
    """Returns a list of `num_crops` tensors, each of shape (B, 3, 96, 96): one local view per image."""
    size = 96  # Smaller crops for local views
    local_transform = transforms.Compose([
        transforms.RandomResizedCrop(size, scale=(0.05, 0.4)),  # Smaller, more concentrated crops
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),  # Same level of jittering
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Apply the random transformation num_crops times to every image
    return [torch.stack([local_transform(img) for img in images]) for _ in range(num_crops)]
```

Distillation Loss:

Here, we measure how closely the student's output distribution matches the teacher's using cross-entropy. We do the following:

  1. Center the teacher's output (subtract the running center), sharpen it by dividing by a low temperature tau_t, and apply softmax.
  2. Divide the student's output by its temperature tau_s and apply log-softmax.
  3. Take the cross-entropy between the two distributions, averaged over the batch.
```python
import torch.nn.functional as F


def distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
    """
    Calculates distillation loss with centering and sharpening (function H in pseudocode).
    """
    # Detach teacher output to stop gradients.
    teacher_output = teacher_output.detach()

    # Center and sharpen teacher's outputs
    teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)

    # Sharpen student's outputs (log-probabilities for the cross-entropy)
    student_log_probs = F.log_softmax(student_output / tau_s, dim=1)

    # Cross-entropy between teacher and student distributions, averaged over the batch
    loss = -(teacher_probs * student_log_probs).sum(dim=1).mean()
    return loss
```

Centering: The teacher's output is centered by subtracting a running mean of its outputs (the center). This prevents any single output dimension from dominating for every image, which is one way training can collapse. On its own, however, centering pushes the teacher toward a uniform output distribution.

Sharpening: Sharpening divides the teacher's logits by a low temperature (tau_t) before the softmax, which makes its distribution more peaked and emphasizes its most confident dimensions. This has the opposite effect of centering; applying both balances them, so the student receives targets that are neither uniform nor dominated by a single dimension.
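A small numerical illustration of both effects, using random logits as a hypothetical stand-in for teacher outputs (the helper entropy and the +10 bias on dimension 0 are illustrative assumptions). A low temperature lowers the entropy of each distribution, and subtracting the batch mean removes a bias that would otherwise make the same dimension win for every sample.

```python
import torch
import torch.nn.functional as F


def entropy(probs: torch.Tensor) -> torch.Tensor:
    """Shannon entropy (in nats) of each row of a probability matrix."""
    return -(probs * probs.clamp_min(1e-12).log()).sum(dim=1)


torch.manual_seed(0)
teacher_logits = torch.randn(16, 10)
teacher_logits[:, 0] += 10.0  # dimension 0 dominates for every sample (a collapse pattern)

# Sharpening: a low temperature makes each distribution more peaked (lower entropy)
soft = F.softmax(teacher_logits / 1.0, dim=1)
sharp = F.softmax(teacher_logits / 0.04, dim=1)
print(entropy(soft).mean().item(), entropy(sharp).mean().item())

# Centering: subtracting the batch mean removes the bias shared by all samples,
# so dimension 0 no longer wins for every input
center = teacher_logits.mean(dim=0, keepdim=True)
uncentered_winner = F.softmax(teacher_logits / 0.04, dim=1).argmax(dim=1)
centered_winner = F.softmax((teacher_logits - center) / 0.04, dim=1).argmax(dim=1)
print(uncentered_winner.unique().numel(), centered_winner.unique().numel())
```

Without centering, every sample maps to dimension 0; after centering, the winning dimension varies across samples.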

NOTE: I highly recommend watching the first 10 minutes of this video https://www.youtube.com/watch?v=BFivrO_PXt4 to get a deeper understanding of centering and sharpening, as these are important techniques in deep learning.

Training the DINO Model:

Figure: DINO pseudocode (Algorithm 1), taken from the official paper

There are 3 important steps to stress:

1. Generating two augmented views (x1, x2) of each input, which are passed through both the student and the teacher.

2. The distillation loss function that we talked about earlier. Notice that it is computed across views: the teacher's output on x1 is the target for the student's output on x2, and vice versa, i.e. loss = H(gt(x1), gs(x2))/2 + H(gt(x2), gs(x1))/2.

3. Updating (a) the student's parameters, (b) the teacher's parameters, and (c) the center. The key point is that the teacher's parameters are updated with an exponential moving average (EMA) rather than by gradient descent.

  • Teacher Parameters: The teacher is not trained by gradient descent. After each student update, its parameters are replaced by an exponential moving average of their previous values and the student's current parameters. This yields a smoother, more stable version of the student that provides the training targets.
  • Center: EMA is also used to update the center, the running mean of the teacher's outputs that is subtracted from the teacher's logits before the softmax. Because it evolves gradually throughout training, the center provides a stable reference point.
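The paper increases the teacher momentum from 0.996 to 1 with a cosine schedule over training, and updates the center as an EMA of the teacher's batch mean. Below is a minimal sketch of both; the function names teacher_momentum and update_center are hypothetical (the official code uses its own scheduler helper), and the schedule here is defined per step so that the last step reaches exactly 1.

```python
import math

import torch


def teacher_momentum(step: int, total_steps: int, base: float = 0.996, final: float = 1.0) -> float:
    """Cosine schedule for the teacher EMA decay: `base` at step 0, `final` at the last step."""
    progress = step / max(total_steps - 1, 1)
    return final - (final - base) * (math.cos(math.pi * progress) + 1) / 2


@torch.no_grad()
def update_center(center: torch.Tensor, teacher_outputs: torch.Tensor, m: float = 0.9) -> torch.Tensor:
    """EMA of the batch mean of the teacher's outputs: C <- m * C + (1 - m) * mean(t)."""
    batch_mean = teacher_outputs.mean(dim=0, keepdim=True)
    return m * center + (1 - m) * batch_mean


total = 1000
print(teacher_momentum(0, total), teacher_momentum(total // 2, total), teacher_momentum(total - 1, total))

center = torch.zeros(1, 4)
center = update_center(center, torch.ones(8, 4))
print(center)  # every entry moves from 0 to 0.1 after one update
```

As the momentum approaches 1 late in training, the teacher effectively freezes, while the center keeps tracking the current output distribution.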

DINO Model

Here is how the final DINO model implementation will look

```python
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class DINO(nn.Module):
    def __init__(self, student_arch: Callable[[], nn.Module], teacher_arch: Callable[[], nn.Module],
                 device: torch.device):
        """
        Args:
            student_arch (Callable): constructor (e.g. the ViT class) that builds the student;
                the resulting model must expose an `output_dim` attribute
            teacher_arch (Callable): constructor that builds the teacher (same architecture)
            device: torch.device ('cuda' or 'cpu')
        """
        super().__init__()

        self.student = student_arch().to(device)
        self.teacher = teacher_arch().to(device)
        # The teacher starts as an exact copy of the student
        self.teacher.load_state_dict(self.student.state_dict())

        # Center stored as a buffer: saved with the model, but not a trainable parameter
        self.register_buffer('center', torch.zeros(1, self.student.output_dim, device=device))

        # The teacher is updated only by EMA, never by backpropagation
        for param in self.teacher.parameters():
            param.requires_grad = False

    @staticmethod
    def distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
        """
        Calculates distillation loss with centering and sharpening (function H in pseudocode).
        """
        # Detach teacher output to stop gradients.
        teacher_output = teacher_output.detach()

        # Center and sharpen teacher's outputs
        teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)

        # Sharpen student's outputs (log-probabilities for the cross-entropy)
        student_log_probs = F.log_softmax(student_output / tau_s, dim=1)

        # Cross-entropy between teacher and student distributions, averaged over the batch
        loss = -(teacher_probs * student_log_probs).sum(dim=1).mean()
        return loss

    @torch.no_grad()
    def teacher_update(self, beta: float):
        """EMA update: teacher = beta * teacher + (1 - beta) * student."""
        for teacher_params, student_params in zip(self.teacher.parameters(), self.student.parameters()):
            teacher_params.mul_(beta).add_(student_params, alpha=1 - beta)
```

To update the teacher's parameters we use the formula proposed in the paper, gt.params = beta * gt.params + (1 - beta) * gs.params, where beta is the moving-average decay (in the paper it follows a cosine schedule from 0.996 to 1 during training) and gt, gs are the teacher and student networks, respectively.

Further, in __init__ the teacher's parameters are set to requires_grad = False because we don't want to update them during backpropagation, but only through the moving-average update.

Also, registering a tensor as a buffer is the standard PyTorch way to store state that is not a trainable parameter: the optimizer never updates it, but it is saved in the state_dict and moved to the right device together with the module.
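A quick way to see the difference between a parameter and a buffer is to inspect a small module. WithCenter below is a hypothetical toy module, not part of the DINO code; it holds one trainable layer and a center buffer registered the same way as in the DINO class.

```python
import torch
import torch.nn as nn


class WithCenter(nn.Module):
    """Hypothetical toy module: one trainable layer plus a non-trainable center buffer."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.register_buffer("center", torch.zeros(1, dim))


model = WithCenter()
param_names = [name for name, _ in model.named_parameters()]
buffer_names = [name for name, _ in model.named_buffers()]
print(param_names)   # ['proj.weight', 'proj.bias']
print(buffer_names)  # ['center']
print("center" in model.state_dict())  # True: saved and loaded with checkpoints
```

Since the optimizer is built from parameters(), the center is never touched by optimizer.step(), yet it still travels with model.to(device) and with saved checkpoints.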

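The DINO model can then be instantiated as follows. Note that we pass the ViT class itself rather than an instance, because DINO calls it to build the student and the teacher:

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dino = DINO(ViT, ViT, device)
```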

Here, the student and teacher architectures are nothing but the standard Vision Transformer from the original ViT paper (e.g. ViT-B/16); the DINO paper itself reports results with ViT-S and ViT-B backbones using patch sizes 16 and 8. If you aren't familiar with the base ViT architecture, I recommend this blog: https://medium.com/thedeephub/building-vision-transformer-from-scratch-using-pytorch-an-image-worth-16x16-words-24db5f159e27
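In the paper, the backbone is followed by a projection head: a 3-layer MLP (hidden dimension 2048), L2 normalization, and a final linear layer with K output dimensions. Below is a simplified sketch under these assumptions; DINOHeadSketch and BackboneWithHead are hypothetical names, the bottleneck size of 256 is an assumed default, and the weight normalization that the official head applies to its last layer is omitted. The wrapper exposes output_dim, the attribute the DINO class reads.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DINOHeadSketch(nn.Module):
    """Simplified DINO-style projection head: 3-layer MLP -> L2 norm -> linear to K dims."""

    def __init__(self, in_dim: int, out_dim: int = 65536, hidden_dim: int = 2048, bottleneck_dim: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, bottleneck_dim),
        )
        # The official head weight-normalizes this layer; a plain linear layer is used here
        self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)  # L2 normalization before the final projection
        return self.last_layer(x)


class BackboneWithHead(nn.Module):
    """Wraps a backbone that maps images to (N, embed_dim) features, e.g. a ViT's [CLS] token."""

    def __init__(self, backbone: nn.Module, embed_dim: int, out_dim: int):
        super().__init__()
        self.backbone = backbone
        self.head = DINOHeadSketch(embed_dim, out_dim)
        self.output_dim = out_dim  # read by the DINO class to size the center

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))


torch.manual_seed(0)
toy_backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 192))  # stand-in for a ViT
model = BackboneWithHead(toy_backbone, embed_dim=192, out_dim=1024)
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 1024])
```

To use such a wrapper with DINO, you could pass a zero-argument constructor, e.g. functools.partial or a lambda that builds a fresh ViT plus head.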

The Final Training Loop

The entire implementation can now be put together in a training loop, as proposed in the paper.

This is how the final code will look.

```python
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader


def train_dino(dino: DINO,
               data_loader: DataLoader,
               optimizer: Optimizer,
               device: torch.device,
               num_epochs: int,
               tps: float = 0.1,
               tpt: float = 0.04,
               beta: float = 0.996,
               m: float = 0.9,
               ):
    """
    Args:
        dino: DINO module
        data_loader (DataLoader): yields batches as lists of PIL images
        optimizer (Optimizer): optimizer for the student (AdamW, SGD, etc.)
        device (torch.device): 'cuda' or 'cpu'
        num_epochs (int): number of epochs
        tps (float): temperature for sharpening student logits
        tpt (float): temperature for sharpening teacher logits
        beta (float): teacher moving-average decay
        m (float): center moving-average decay
    """
    for epoch in range(num_epochs):
        print(f"Epoch: {epoch + 1}/{num_epochs}")
        for x in data_loader:
            # Two random global views of the same batch, as in the paper's pseudocode
            x1 = global_augment(x).to(device)
            x2 = global_augment(x).to(device)

            student_output1, student_output2 = dino.student(x1), dino.student(x2)
            with torch.no_grad():
                teacher_output1, teacher_output2 = dino.teacher(x1), dino.teacher(x2)

            # Cross-view loss H(t1, s2)/2 + H(t2, s1)/2; the signature is (student, teacher, ...)
            loss = (dino.distillation_loss(student_output2, teacher_output1, dino.center, tps, tpt) +
                    dino.distillation_loss(student_output1, teacher_output2, dino.center, tps, tpt)) / 2

            # Backpropagation (updates the student only)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update the teacher network parameters via EMA
            dino.teacher_update(beta)

            # Update the center via EMA of the teacher's batch mean
            with torch.no_grad():
                batch_center = torch.cat([teacher_output1, teacher_output2], dim=0).mean(dim=0, keepdim=True)
                dino.center.mul_(m).add_(batch_center, alpha=1 - m)
```

  1. We create two random global views, x1 and x2, of each batch, matching the two-view pseudocode in the paper. (In the full multi-crop setup, the local crops from multiple_local_augments are passed only through the student, which requires a backbone that accepts 96×96 inputs.)
  2. We then get the outputs of both the student and the teacher for both views, as proposed in the paper; recall the algorithm figure above.
  3. The teacher's forward pass runs under torch.no_grad(), so no computation graph is built for it. Its parameters already have requires_grad = False and are never updated through backpropagation.
  4. We then calculate the distillation loss symmetrically, again following the paper: the teacher's output on one view is the target for the student's output on the other view.
  5. Within the distillation loss, we first center the teacher's outputs by subtracting the running center. This keeps any single output dimension from dominating for every image, which is one way the model can collapse.
  6. We then sharpen the teacher's outputs with a low temperature (tpt). Centering alone pushes the outputs toward a uniform distribution, the other collapse mode; sharpening has the opposite effect, and together they balance each other, giving the student a confident but non-trivial target.
  7. We then backpropagate, call optimizer.step() to update the student, and update the teacher through the exponential moving average implemented before.
  8. As the final step, again under torch.no_grad(), we update the center as a moving average of the teacher's outputs, so it stays consistent with the changing output distribution throughout training.
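The loop compares two views at a time. In the paper's full multi-crop setup, all crops go through the student while only the two global crops go through the teacher, and the loss averages the cross-entropy over every (teacher view, student view) pair that uses different views. Here is a minimal sketch of that loss on precomputed outputs; multicrop_dino_loss is a hypothetical helper, and it assumes the first entries of student_outs are the same global views, in the same order, as teacher_outs.

```python
import torch
import torch.nn.functional as F


def multicrop_dino_loss(student_outs, teacher_outs, center, tau_s=0.1, tau_t=0.04):
    """Average cross-entropy between each teacher (global) view and every *other* student view.

    student_outs: list of (N, K) tensors, global views first, then local views.
    teacher_outs: list of (N, K) tensors for the global views only, in the same order.
    """
    total, n_terms = 0.0, 0
    for t_idx, t_out in enumerate(teacher_outs):
        t_probs = F.softmax((t_out.detach() - center) / tau_t, dim=1)  # center + sharpen, no gradient
        for s_idx, s_out in enumerate(student_outs):
            if s_idx == t_idx:
                continue  # skip pairs where student and teacher see the same view
            s_log_probs = F.log_softmax(s_out / tau_s, dim=1)
            total = total + (-(t_probs * s_log_probs).sum(dim=1).mean())
            n_terms += 1
    return total / n_terms


torch.manual_seed(0)
K, N = 16, 4
student_outs = [torch.randn(N, K, requires_grad=True) for _ in range(2 + 6)]  # 2 global + 6 local
teacher_outs = [torch.randn(N, K) for _ in range(2)]  # teacher sees global views only
loss = multicrop_dino_loss(student_outs, teacher_outs, center=torch.zeros(1, K))
loss.backward()
print(loss.item())
```

With 2 global and 6 local views this averages 14 terms, and every student view, including each local crop, receives a gradient.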

And that is it: this is how you train a DINO model from scratch. So far in the vision transformer series we've implemented the standard ViT, Swin, CvT, MAE, and DINO (self-supervised). I hope you enjoyed reading this article.

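Here's how you'll fit it (the paper uses a student temperature of 0.1 and a teacher temperature of 0.04, warmed up to 0.07):

```python
import torch
from torch.utils.data import DataLoader

# CustomDataset is your own Dataset class that returns one PIL image per index
dataset = CustomDataset()
# collate_fn keeps each batch as a plain list of PIL images for the augmentation functions
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=lambda batch: batch)
# Only the student is trained by the optimizer; the teacher follows via EMA
optimizer = torch.optim.AdamW(dino.student.parameters(), lr=1e-4)
train_dino(dino,
           data_loader=dataloader,
           optimizer=optimizer,
           device=device,
           num_epochs=300,
           tps=0.1,
           tpt=0.04,
           beta=0.996,
           m=0.9)
```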

If you liked my story, please follow my page. I also have a GitHub repository, ML-Models; I would greatly appreciate it if you checked it out and gave it a star if you find it helpful. The entire code is available here: https://github.com/mishra-18/ML-Models/blob/main/Vision%20Transformers/dino.py
