FaKe-ViT-B/16: Replicating and Fine-Tuning the Vision Transformer paper from scratch with PyTorch for Fake/AI-generated image detection task

Mohd Zeeshan
16 min read · Feb 16, 2024

Hello Everyone!

In this article we will be implementing the Vision Transformer (ViT) architecture from scratch, directly from the research paper. Then we will fine-tune the pre-trained ViT-Base model for an AI-generated image detection task.

Before we start, here’s the GitHub repo for the codebase and a HuggingFace demo to understand what we are going to do:

GitHub: https://github.com/zappy586/FAKE-ViT

HuggingFace Demo: https://huggingface.co/spaces/Zappy586/Fake-ViT

Introduction:

The Vision Transformer (ViT) was introduced by the Google Brain team in 2021. It represents a significant shift in image classification, moving away from the Convolutional Neural Networks (CNNs) that have long been the de facto choice for visual data.

ViT operates almost identically to Transformers used in language processing, using self-attention, rather than convolution, to aggregate information across locations. This is a departure from a large body of prior work, which focused on incorporating image-specific inductive biases.

The key difference between ViTs and CNNs lies in their internal representation structure. ViTs have more uniform representations across all layers, enabled by self-attention, which allows early aggregation of global information, and ViT residual connections, which strongly propagate features from lower to higher layers.

ViT architecture

Forward Pass:

The authors first split the image into 16x16 patches. With an image size of 224x224, this gives (224/16)² = 196 patches. Each patch is flattened and linearly projected into a 768-dimensional vector (16·16·3 = 768), known as a patch embedding. A learnable class token (like BERT's [CLS]) is prepended to the sequence, and positional embeddings are added to preserve patch ordering. Finally, the sequence is sent through a stack of Transformer encoder layers, and a single MLP head on the class token performs the classification.
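To make those numbers concrete, here is a tiny sanity check of that arithmetic (a minimal sketch, not part of the original notebook):

# Patch math for ViT-B/16 on a 224x224 RGB image
img_size, patch_size, channels = 224, 16, 3

num_patches = (img_size // patch_size) ** 2     # (224/16)^2 = 196 patches
patch_dim = patch_size * patch_size * channels  # 16*16*3 = 768 values per flattened patch

print(num_patches, patch_dim)  # 196 768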

Loading the Dataset:

We will be using the train split of the ‘Fake or Real Competition’ dataset, which has about 7,000 normal and AI-generated images. Let’s load it in our Google Colab notebook:

If you want to directly use my notebook, you can do so here. I would highly suggest keeping it open while you follow along. Also, please leave a star on the repo if you liked my work :)

from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

transforms = v2.Compose([
    v2.Resize((224, 224)),
    v2.RandomHorizontalFlip(),
    v2.ToTensor()
])

data = ImageFolder(root="/content/drive/MyDrive/train", transform=transforms)
data, data.classes
(Dataset ImageFolder
     Number of datapoints: 6750
     Root location: /content/drive/MyDrive/train
     StandardTransform
 Transform: Compose(
       Resize(size=[224, 224], interpolation=InterpolationMode.BILINEAR, antialias=warn)
       RandomHorizontalFlip(p=0.5)
       ToTensor()
 ),
 ['fake_images', 'real_images'])

We will be using the ImageFolder class of the torchvision.datasets module to load the dataset. There are only 2 classes: fake_images and real_images. These are some sample images from the dataset:

Fake image
Real Image
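ImageFolder assigns integer labels in alphabetical order of the folder names, so fake_images maps to 0 and real_images to 1. A quick check (a small sketch, assuming the dataset loaded above):

# Label mapping follows alphabetical folder order in ImageFolder
print(data.class_to_idx)  # expected: {'fake_images': 0, 'real_images': 1}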

Now we will subset the dataset into train and test splits:

from torch.utils.data import Subset

data_len = len(data)
n_test = int(0.10 * data_len)
print(n_test)
test_data = Subset(data, range(n_test))
train_data = Subset(data, range(n_test, data_len))
len(train_data), len(test_data)
(6075, 675)

We will use the first 10% of the dataset for testing. Note that ImageFolder orders samples by class, so this slice likely comes mostly (or entirely) from the fake_images folder; if you want a shuffled, roughly class-balanced split instead, see the sketch below.
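A drop-in alternative using torch.utils.data.random_split (a sketch, not what the original notebook used):

import torch
from torch.utils.data import random_split

# Reproducible 90/10 random split instead of the sequential Subset above
generator = torch.Generator().manual_seed(42)
train_data, test_data = random_split(data, [data_len - n_test, n_test], generator=generator)
len(train_data), len(test_data)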

Let’s prepare the dataloaders now:

import os
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_WORKERS = 8
class_names = data.classes  # ['fake_images', 'real_images']

train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              drop_last=True)

test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS,
                             drop_last=True)

len(train_dataloader), len(test_dataloader), class_names

Creating Patches and Embeddings

The authors split each image into a grid of 16x16 patches, and each patch is projected into a 768-dimensional vector using a convolution layer (kernel size 16, stride 16).

So, doing the math, a 224x224 image is converted from a tensor of shape (3, 224, 224) into a sequence of shape (1, 196, 768). The patches can be visualized something like this:

The convolutional maps of each patch look something like this:

Convolution maps of a patch

Then finally each patch is flattened into a single 768-dimensional vector which looks something like this:
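To see those shapes in code before we wrap them in a class, here is a minimal sketch using a dummy image tensor:

import torch
from torch import nn

dummy_img = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)

conv = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
flatten = nn.Flatten(start_dim=2, end_dim=3)

patches = conv(dummy_img)           # -> (1, 768, 14, 14): one 768-dim vector per 16x16 patch
patches = flatten(patches)          # -> (1, 768, 196)
patches = patches.permute(0, 2, 1)  # -> (1, 196, 768): sequence of 196 patch embeddings
patches.shape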

Now that we have visualized what we have to do, let’s get to coding!

Let’s create a patch embedding layer that performs the convolution and flatten operations to create these vectors. We also add the class token and positional embedding:

import torch
from torch import nn

class PatchEmbedding(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 embedding_dim: int = 768,
                 batch_size: int = 1,
                 img_height: int = 224,
                 img_width: int = 224):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = int((img_height * img_width) / patch_size**2)
        # Learnable class token, prepended to the patch sequence (like BERT's [CLS]).
        # Note: it is created with a fixed batch dimension, so inputs must have exactly
        # this batch size (drop_last=True in the dataloaders guarantees that).
        self.class_token = nn.Parameter(torch.randn(batch_size, 1, embedding_dim),
                                        requires_grad=True)
        # Learnable positional embedding for the class token + all patches
        self.pos_embedding = nn.Parameter(torch.randn(batch_size, self.num_patches + 1, embedding_dim),
                                          requires_grad=True)
        # Conv2d with kernel = stride = patch_size patchifies and projects in one step
        self.img_embedding_layer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=embedding_dim,
                      kernel_size=patch_size,
                      stride=patch_size,
                      padding=0),
            nn.Flatten(start_dim=2, end_dim=3))

    def forward(self, x):
        image_resolution = x.shape[-1]
        assert image_resolution % self.patch_size == 0, \
            f"Input image size {image_resolution} must be divisible by patch size {self.patch_size}"
        x = self.img_embedding_layer(x)              # (B, 768, 196)
        x = x.permute(0, 2, 1)                       # (B, 196, 768)
        x = torch.cat((self.class_token, x), dim=1)  # prepend class token -> (B, 197, 768)
        x = self.pos_embedding + x                   # add positional embeddings
        return x

patchify = PatchEmbedding()
patch_embedded_image = patchify(img.unsqueeze(0))  # img: a single (3, 224, 224) sample image from the dataset
patch_embedded_image.shape  # torch.Size([1, 197, 768])

Multi-Head Self Attention Block:

A multi-head self-attention block looks something like this:

Multi-Head Self Attention

MSA was introduced in the legendary “Attention Is All You Need” (Vaswani et al.) paper, whose encoder is essentially the architecture we will be using. In simple terms, self-attention projects each token into several learnable linear spaces (queries, keys and values), applies scaled dot-product attention in each head, and concatenates the heads at the end. In ViT’s pre-norm design, layer normalization is applied to the input before the attention, and the attention output is then added back to the input through a residual/skip connection.
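As a quick illustration (a minimal sketch, not part of the model we build below), single-head scaled dot-product attention can be written as:

import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch, seq_len, head_dim)
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq_len, seq_len)
    weights = torch.softmax(scores, dim=-1)            # attention weights over the sequence
    return weights @ v                                 # weighted sum of the values

head_input = torch.randn(1, 197, 64)  # one head's 64-dim slice of our 197-token sequence
scaled_dot_product_attention(head_input, head_input, head_input).shape  # torch.Size([1, 197, 64])

Multi-head attention runs 12 of these in parallel on 64-dimensional slices of the 768-dimensional embeddings and concatenates the results; PyTorch’s nn.MultiheadAttention handles all of that for us.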

Let’s code the MSA block:

class MultiHeadSelfAttentionBlock(nn.Module):
    def __init__(self,
                 embedding_dim: int = 768,
                 num_heads: int = 12,
                 attn_dropout: float = 0.0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True)

    def forward(self, x):
        x = self.layer_norm(x)  # pre-norm, as in the ViT paper
        attn_output, _ = self.multihead_attn(query=x,
                                             key=x,
                                             value=x,
                                             need_weights=False)
        # The residual connection is added later, in the encoder block
        return attn_output

image_msa_block = MultiHeadSelfAttentionBlock()
msa_output = image_msa_block(patch_embedded_image)
msa_output.shape  # torch.Size([1, 197, 768])

MLP Block

The transformer encoder block also contains an MLP block, which starts with a layer norm and applies a GELU activation and dropout after each linear layer. So let’s code that out too (the MLP is basically two fully connected layers):

class MLPBlock(nn.Module):
    def __init__(self,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 dropout: float = 0.1):
        super().__init__()
        self.layer_1 = nn.Linear(in_features=embedding_dim, out_features=mlp_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_2 = nn.Linear(in_features=mlp_size, out_features=embedding_dim)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x = self.layer_norm(x)  # pre-norm, as in the ViT paper
        x = self.layer_1(x)     # 768 -> 3072
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.layer_2(x)     # 3072 -> 768
        x = self.dropout(x)
        return x

mlp_layer = MLPBlock()
mlp_output = mlp_layer(msa_output)
mlp_output.shape  # torch.Size([1, 197, 768])

Transformer-Encoder Block

The transformer-encoder block then puts both of these blocks together, wrapping each in a residual (skip) connection, to form a full encoder layer.
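In the paper's notation, each encoder layer computes the following (LN is LayerNorm and $z_{\ell-1}$ is the layer input):

$$z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, \qquad z_\ell = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell$$

Let's code it out: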

class TransformerEncoderBlock(nn.Module):
    def __init__(self,
                 embedding_dim: int = 768,
                 num_heads: int = 12,
                 mlp_size: int = 3072,
                 mlp_dropout: float = 0.1,
                 attn_dropout: float = 0.0):
        super().__init__()
        self.msa_block = MultiHeadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)
        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                                  mlp_size=mlp_size,
                                  dropout=mlp_dropout)

    def forward(self, x):
        x = self.msa_block(x) + x  # residual connection around the MSA block
        x = self.mlp_block(x) + x  # residual connection around the MLP block
        return x
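The summary tables in this article come from torchinfo. The model_summary helper used to produce them isn't shown in the original snippets; here is a minimal sketch of how it might be defined (assuming torchinfo is installed):

# !pip install torchinfo
from torchinfo import summary

def model_summary(model, input_size):
    # Wrapper around torchinfo.summary with the column layout used in the tables below
    return summary(model=model,
                   input_size=input_size,
                   col_names=["input_size", "output_size", "num_params", "trainable"],
                   col_width=20,
                   row_settings=["var_names"])

model_summary(model=TransformerEncoderBlock(), input_size=(1, 197, 768))

Running it on our encoder block gives the table below: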
=======================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
=======================================================================================================================================
TransformerEncoderBlock (TransformerEncoderBlock) [1, 197, 768] [1, 197, 768] -- True
├─MultiHeadSelfAttentionBlock (msa_block) [1, 197, 768] [1, 197, 768] -- True
│ └─LayerNorm (layer_norm) [1, 197, 768] [1, 197, 768] 1,536 True
│ └─MultiheadAttention (multihead_attn) -- [1, 197, 768] 2,362,368 True
├─MLPBlock (mlp_block) [1, 197, 768] [1, 197, 768] -- True
│ └─LayerNorm (layer_norm) [1, 197, 768] [1, 197, 768] 1,536 True
│ └─Linear (layer_1) [1, 197, 768] [1, 197, 3072] 2,362,368 True
│ └─GELU (gelu) [1, 197, 3072] [1, 197, 3072] -- --
│ └─Dropout (dropout) [1, 197, 3072] [1, 197, 3072] -- --
│ └─Linear (layer_2) [1, 197, 3072] [1, 197, 768] 2,360,064 True
│ └─Dropout (dropout) [1, 197, 768] [1, 197, 768] -- --
=======================================================================================================================================
Total params: 7,087,872
Trainable params: 7,087,872
Non-trainable params: 0
Total mult-adds (M): 4.73
=======================================================================================================================================
Input size (MB): 0.61
Forward/backward pass size (MB): 8.47
Params size (MB): 18.90
Estimated Total Size (MB): 27.98
=======================================================================================================================================

So far, this is what a single encoder layer looks like. It’s about a 7-million-parameter block on its own.

We can also do all of this with a single line of code thanks to the PyTorch team who abstracted it behind a class:

# Creating a transformer encoder layer with in-built PyTorch layers
torch_transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=768,
                                                             nhead=12,
                                                             dim_feedforward=3072,
                                                             dropout=0.1,
                                                             activation="gelu",
                                                             batch_first=True,
                                                             norm_first=True,
                                                             device=device)

model_summary(model=torch_transformer_encoder_layer,
              input_size=(1, 197, 768))
==================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
==================================================================================================================================
TransformerEncoderLayer (TransformerEncoderLayer) [1, 197, 768] [1, 197, 768] 7,087,872 True
==================================================================================================================================
Total params: 7,087,872
Trainable params: 7,087,872
Non-trainable params: 0
Total mult-adds (M): 0
==================================================================================================================================
Input size (MB): 0.61
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.61
==================================================================================================================================

This has the same number of parameters, but the reported size is a fraction of before, i.e., 0.61 MB.

That dramatic difference is mostly a reporting artifact: torchinfo treats the built-in nn.TransformerEncoderLayer as a single leaf module, so it doesn’t itemize the intermediate activations and per-layer storage it listed for our hand-built block. PyTorch’s built-in layer does also ship some genuine under-the-hood optimizations (such as a fused fast path), but don’t read the table as a 30x memory saving.

Putting it all Together!

Yay! We’ve built the biggest part of the model, the transformer encoder. Now we just have to put it together inside a single class, along with a classifier head at the end:

class ViTModel(nn.Module):
    def __init__(self,
                 batch_size: int = 32,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 num_heads: int = 12,
                 attn_dropout: float = 0.1,
                 mlp_dropout: float = 0.1,
                 device: torch.device = device,
                 mlp_size: int = 3072,
                 num_classes: int = 1000):
        super().__init__()

        self.patch_embeddings = PatchEmbedding(in_channels=in_channels,
                                               patch_size=patch_size,
                                               embedding_dim=embedding_dim,
                                               batch_size=batch_size,
                                               img_height=img_size,
                                               img_width=img_size)

        # Stack of encoder layers, using PyTorch's built-in implementation
        self.transformer_encoder = nn.Sequential(*[nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                                              nhead=num_heads,
                                                                              dim_feedforward=mlp_size,
                                                                              dropout=attn_dropout,
                                                                              activation="gelu",
                                                                              batch_first=True,
                                                                              norm_first=True,
                                                                              device=device) for _ in range(num_transformer_layers)])
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes)
        )

    def forward(self, x):
        x = self.patch_embeddings(x)     # (B, 197, 768)
        x = self.transformer_encoder(x)  # (B, 197, 768)
        x = self.classifier(x[:, 0])     # classify on the class token only -> (B, num_classes)
        return x

If you look closely at the classifier, we only feed in the first token of the sequence (x[:, 0]), i.e. the class token, rather than all 197 token embeddings. This follows the paper: the class token’s final representation is treated as the image representation and goes to the classification head (the paper also experiments with average-pooling the tokens instead, with similar results).
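A quick sketch of what that indexing does, using a dummy encoder output:

# Dummy encoder output: batch of 32 sequences of 197 tokens (class token + 196 patches)
tokens = torch.randn(32, 197, 768)

cls_token_final = tokens[:, 0]               # select the class token from every sequence -> (32, 768)
logits = nn.Linear(768, 2)(cls_token_final)  # classifier head -> (32, 2)
logits.shape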

Let’s instantiate the model and see what’s under the hood:

model_0 = ViTModel(batch_size=32,
                   num_classes=2).to(device)

input_size = (32, 3, 224, 224)
model_summary(model=model_0,
              input_size=input_size)
=============================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
=============================================================================================================================
ViTModel (ViTModel) [32, 3, 224, 224] [32, 2] -- True
├─PatchEmbedding (patch_embeddings) [32, 3, 224, 224] [32, 197, 768] 4,866,048 True
│ └─Sequential (img_embedding_layer) [32, 3, 224, 224] [32, 768, 196] -- True
│ │ └─Conv2d (0) [32, 3, 224, 224] [32, 768, 14, 14] 590,592 True
│ │ └─Flatten (1) [32, 768, 14, 14] [32, 768, 196] -- --
├─Sequential (transformer_encoder) [32, 197, 768] [32, 197, 768] -- True
│ └─TransformerEncoderLayer (0) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (1) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (2) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (3) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (4) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (5) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (6) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (7) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (8) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (9) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (10) [32, 197, 768] [32, 197, 768] 7,087,872 True
│ └─TransformerEncoderLayer (11) [32, 197, 768] [32, 197, 768] 7,087,872 True
├─Sequential (classifier) [32, 768] [32, 2] -- True
│ └─LayerNorm (0) [32, 768] [32, 768] 1,536 True
│ └─Linear (1) [32, 768] [32, 2] 1,538 True
=============================================================================================================================
Total params: 90,514,178
Trainable params: 90,514,178
Non-trainable params: 0
Total mult-adds (G): 3.70
=============================================================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 38.73
Params size (MB): 2.37
Estimated Total Size (MB): 60.37
=============================================================================================================================

There are about 90 million parameters in our model. That means every forward pass uses all 90 million weights, and every backward pass computes gradients for (and then updates) all 90 million of them!🤯 How cool is that?

Setting up Loss function and optimizer

For the loss function we will use CrossEntropyLoss, since this is a two-class classification task, and for the optimizer we will use Adam with betas of (0.9, 0.999) and a weight decay of 0.1, matching the optimizer settings the authors used for training:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(),
                             lr=0.001,
                             betas=(0.9, 0.999),
                             weight_decay=0.1)

Setting up Training/Optimization loop

try:
    from torchmetrics.classification import BinaryAccuracy as Accuracy
except ImportError:
    !pip install torchmetrics
    from torchmetrics.classification import BinaryAccuracy as Accuracy

accuracy = Accuracy().to(device)

def train_step(dataloader: torch.utils.data.DataLoader,
               model: torch.nn.Module,
               loss_fn: torch.nn,
               optimizer: torch.optim):
    model.train()
    train_loss, train_acc = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        y_logits = model(X)
        y_preds = torch.argmax(torch.softmax(y_logits, dim=1), dim=1)
        loss = loss_fn(y_logits, y)
        train_loss += loss.item()
        acc = accuracy(y_preds, y)
        train_acc += acc
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc

def test_step(dataloader: torch.utils.data.DataLoader,
              model: torch.nn.Module,
              loss_fn: torch.nn):
    model.eval()
    test_loss, test_acc = 0, 0
    with torch.inference_mode():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            test_logits = model(X)
            loss = loss_fn(test_logits, y)
            test_loss += loss.item()
            test_preds = torch.argmax(torch.softmax(test_logits, dim=1), dim=1)
            acc = accuracy(test_preds, y)
            test_acc += acc
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

from tqdm.auto import tqdm

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          loss_fn: torch.nn,
          optimizer: torch.optim,
          epochs: int):
    results = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(dataloader=train_dataloader,
                                           model=model,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer)
        test_loss, test_acc = test_step(dataloader=test_dataloader,
                                        model=model,
                                        loss_fn=loss_fn)
        results['train_loss'].append(train_loss)
        results['train_acc'].append(train_acc)
        results['test_loss'].append(test_loss)
        results['test_acc'].append(test_acc)
        print(f"Epoch: {epoch} | Train Loss: {train_loss:.5f} | Train Accuracy: {train_acc:.2f}% | Test Loss: {test_loss:.5f} | Test Accuracy: {test_acc:.2f}%")
    return results

Let’s start the training!

# Start the training
results = train(model=model_0,
                train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                loss_fn=loss_fn,
                optimizer=optimizer,
                epochs=3)
results
0%|          | 0/3 [00:00<?, ?it/s]
Epoch: 0 | Train Loss: 0.75000 | Train Accuracy: 0.51% | Test Loss: 0.79582 | Test Accuracy: 0.55%
Epoch: 1 | Train Loss: 0.75157 | Train Accuracy: 0.51% | Test Loss: 0.79581 | Test Accuracy: 0.55%
Epoch: 2 | Train Loss: 0.75964 | Train Accuracy: 0.50% | Test Loss: 0.79590 | Test Accuracy: 0.55%
{'train_loss': [0.7499971547454753, 0.751566748455088, 0.7596368442767512],
'train_acc': [tensor(0.5056, device='cuda:0'),
tensor(0.5086, device='cuda:0'),
tensor(0.4959, device='cuda:0')],
'test_loss': [0.7958205030077979, 0.795807231040228, 0.7959008500689552],
'test_acc': [tensor(0.5506, device='cuda:0'),
tensor(0.5506, device='cuda:0'),
tensor(0.5476, device='cuda:0')]}
Training and testing metrics

Hmmm… it seems our model got about 55 percent accuracy (the training loop prints accuracy as a fraction, so 0.55 means 55%). That might not sound terrible, but for a binary classification task it means the model is basically guessing. Something isn’t right. Let’s see what the authors have to say about this:

“When trained on mid-sized datasets such as ImageNet without strong regularization, these models yield modest accuracies of a few percentage points below ResNets of comparable size. This seemingly discouraging outcome may be expected: Transformers lack some of the inductive biases inherent to CNNs, such as translation equivariance and locality, and therefore do not generalize well when trained on insufficient amounts of data.
However, the picture changes if the models are trained on larger datasets (14M-300M images). We find that large scale training trumps inductive bias. Our Vision Transformer (ViT) attains excellent results when pre-trained at sufficient scale and transferred to tasks with fewer datapoints. When pre-trained on the public ImageNet-21k dataset or the in-house JFT-300M dataset, ViT approaches or beats state of the art on multiple image recognition benchmarks. In particular, the best model reaches the accuracy of 88.55% on ImageNet, 90.72% on ImageNet-ReaL, 94.55% on CIFAR-100, and 77.63% on the VTAB suite of 19 tasks.”

Ah! So that was the problem. Unlike CNNs and ResNets, Transformers lack image-specific inductive biases (such as translation equivariance and locality), which makes them perform poorly when trained on relatively little data. But when the training data is scaled to very large amounts, ViT matches or beats the best CNN-based models.

So this means we will have to fine-tune an existing pre-trained model for our task, an approach also known as transfer learning.

Fine-Tuning ViT-B/16

Let’s import the model from the built-in models module in torchvision:

import torchvision

pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

# Freeze every parameter of the pre-trained backbone
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

model_summary(model=pretrained_vit,
              input_size=(32, 3, 224, 224))
============================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
============================================================================================================================================
VisionTransformer (VisionTransformer) [32, 3, 224, 224] [32, 1000] 768 False
├─Conv2d (conv_proj) [32, 3, 224, 224] [32, 768, 14, 14] (590,592) False
├─Encoder (encoder) [32, 197, 768] [32, 197, 768] 151,296 False
│ └─Dropout (dropout) [32, 197, 768] [32, 197, 768] -- --
│ └─Sequential (layers) [32, 197, 768] [32, 197, 768] -- False
│ │ └─EncoderBlock (encoder_layer_0) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_1) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_2) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_3) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_4) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_5) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_6) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_7) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_8) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_9) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_10) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_11) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ └─LayerNorm (ln) [32, 197, 768] [32, 197, 768] (1,536) False
├─Sequential (heads) [32, 768] [32, 1000] -- False
│ └─Linear (head) [32, 768] [32, 1000] (769,000) False
============================================================================================================================================
Total params: 86,567,656
Trainable params: 0
Non-trainable params: 86,567,656
Total mult-adds (G): 5.54
============================================================================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 3330.99
Params size (MB): 232.27
Estimated Total Size (MB): 3582.53
============================================================================================================================================

As you can see, the model is loaded and all of its parameters are now frozen (untrainable).

Let’s update the classifier head for our two classes:

pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)  # new trainable head, on the same device as the rest of the model

model_summary(model=pretrained_vit,
              input_size=(32, 3, 224, 224))
============================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
============================================================================================================================================
VisionTransformer (VisionTransformer) [32, 3, 224, 224] [32, 2] 768 Partial
├─Conv2d (conv_proj) [32, 3, 224, 224] [32, 768, 14, 14] (590,592) False
├─Encoder (encoder) [32, 197, 768] [32, 197, 768] 151,296 False
│ └─Dropout (dropout) [32, 197, 768] [32, 197, 768] -- --
│ └─Sequential (layers) [32, 197, 768] [32, 197, 768] -- False
│ │ └─EncoderBlock (encoder_layer_0) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_1) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_2) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_3) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_4) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_5) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_6) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_7) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_8) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_9) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_10) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_11) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ └─LayerNorm (ln) [32, 197, 768] [32, 197, 768] (1,536) False
├─Linear (heads) [32, 768] [32, 2] 1,538 True
============================================================================================================================================
Total params: 85,800,194
Trainable params: 1,538
Non-trainable params: 85,798,656
Total mult-adds (G): 5.52
============================================================================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 3330.74
Params size (MB): 229.20
Estimated Total Size (MB): 3579.20
============================================================================================================================================

There we go! We have frozen all the parameters except the new classifier head.
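You can double-check that only the head will be trained by counting the parameters that still require gradients (a quick sketch):

trainable = sum(p.numel() for p in pretrained_vit.parameters() if p.requires_grad)
total = sum(p.numel() for p in pretrained_vit.parameters())
print(f"{trainable:,} trainable / {total:,} total")  # expected: 1,538 trainable (768*2 weights + 2 biases)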

Data transforms and loading

The pretrained weights ship with the specific image transforms the model was trained with, and we should use the same ones for fine-tuning. So we will grab those transforms, recreate the dataset with them, and then rebuild the dataloaders:
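You can inspect these transforms directly; for the torchvision ViT-B/16 weights they typically resize to 256, center-crop to 224, and normalize with the ImageNet mean/std (a quick check, output shown only approximately):

print(pretrained_vit_weights.transforms())
# Typically something like:
# ImageClassification(crop_size=[224], resize_size=[256],
#                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
#                     interpolation=InterpolationMode.BILINEAR)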

vit_transforms = pretrained_vit_weights.transforms()

data_pretrained = ImageFolder(root="/content/drive/MyDrive/train", transform=vit_transforms)
data_len = len(data_pretrained)
n_test = int(0.10 * data_len)
print(n_test)
test_data = Subset(data_pretrained, range(n_test))
train_data = Subset(data_pretrained, range(n_test, data_len))

import os
from torch.utils.data import DataLoader

BATCH_SIZE = 32
NUM_WORKERS = 8

train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              drop_last=True)

test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS,
                             drop_last=True)

len(train_dataloader), len(test_dataloader), class_names

Setting up the optimizer (we reuse the same CrossEntropyLoss from before):

pretrain_optimizer = torch.optim.Adam(params=pretrained_vit.parameters(),
                                      lr=0.001,
                                      betas=(0.9, 0.999),
                                      weight_decay=0.1)

Starting the training!

Finally, let’s start training this pretrained model:

pretrain_results = train(model=pretrained_vit,
                         train_dataloader=train_dataloader,
                         test_dataloader=test_dataloader,
                         loss_fn=loss_fn,
                         optimizer=pretrain_optimizer,
                         epochs=3)
pretrain_results
 0%|          | 0/3 [00:00<?, ?it/s]
Epoch: 0 | Train Loss: 0.24115 | Train Accuracy: 0.92% | Test Loss: 0.23899 | Test Accuracy: 0.90%
Epoch: 1 | Train Loss: 0.23734 | Train Accuracy: 0.92% | Test Loss: 0.28184 | Test Accuracy: 0.87%
Epoch: 2 | Train Loss: 0.23639 | Train Accuracy: 0.92% | Test Loss: 0.25944 | Test Accuracy: 0.88%
{'train_loss': [0.24115350941028543, 0.23734185647554498, 0.2363924035634944],
'train_acc': [tensor(0.9152, device='cuda:0'),
tensor(0.9168, device='cuda:0'),
tensor(0.9162, device='cuda:0')],
'test_loss': [0.23899360994497934, 0.2818390883150555, 0.2594364704120727],
'test_acc': [tensor(0.8988, device='cuda:0'),
tensor(0.8690, device='cuda:0'),
tensor(0.8765, device='cuda:0')]}
Training and Testing curves of the fine-tuned ViT model

The test accuracy of this model comes out to roughly 88%, a jump of more than 30 percentage points over our from-scratch model. That’s amazing!

Deploying our model:

Let’s deploy our model on HuggingFace spaces using Gradio. The following is the code for app.py:

import gradio as gr
from PIL import Image
import torch
import torchvision.models as models
from torchvision.transforms import v2 as transforms
import os

# Define the class names (order must match the training labels: fake_images=0, real_images=1)
class_names = ['Fake/AI-Generated Image', "Real/Not an AI-Generated Image"]

# Load the model (saved with torch.save(model, ...), so this loads the full module)
weights_path = "FaKe-ViT-B16.pth"
model = torch.load(weights_path, map_location=torch.device('cpu'))
model.eval()

# Preprocessing the image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define the prediction function
def predict_image(image):
    image = preprocess(image)
    if image.shape[0] != 3:
        return "Invalid Image: Image should be in RGB format. Please upload a valid image."
    image = image.unsqueeze(0)
    with torch.inference_mode():
        output = model(image)
    prediction = torch.argmax(torch.softmax(output, dim=1), dim=1).item()
    return class_names[prediction]

demo = gr.Interface(
    predict_image,
    gr.Image(image_mode="RGB", type="pil"),
    "text",
    flagging_options=["incorrect prediction"],
    examples=[
        "images/cheetah.jpg",
        "images/cat.jpg",
        "images/astronaut.jpg",
        "images/mountain.jpg",
        "images/unicorn.jpg"
    ],
    title="<u>FaKe-ViT-B/16: Robust and Fast AI-Generated Image Detection using Vision Transformer(ViT-B/16):</u>",
    description="<p style='font-size: 20px;'>This is a demo to detect AI-Generated images using a fine-tuned Vision Transformer(ViT-B/16). Upload an image and the model will predict whether the image is AI-Generated or Real",
    article="<p style='font-size: 20px;'><b>Paper</b>: 'An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale', Dosovitskiy et al.<br/><b>Dataset</b>: 'Fake or Real competition dataset' at <a href='https://huggingface.co/datasets/mncai/Fake_or_Real_Competition_Dataset'>Fake or Real competition dataset</a>"
)

if __name__ == "__main__":
    demo.launch()

I have stored some example images in the ‘images’ folder in the root directory of the Space. This is what the demo looks like after it’s deployed:

Conclusion

So, in conclusion, we have built a Vision Transformer model from scratch and also fine-tuned a pre-trained one to detect AI-generated images.

If you did it successfully, give yourself a pat on the back! Replicating a research paper that likely took some of the smartest people on Earth months to produce is no easy feat. And even if you couldn’t, that’s fine too.

But I do hope you all got to learn something from this short article and project of mine :)

Next time we will be replicating the recent OmniVec paper, which tops the leaderboards on classification tasks across image, text, audio, 3D and more! Since there is no public codebase for it, it will be challenging, but nothing we can’t handle!

Until then, See you next time!
