Fine-Tuning Vision Transformer with Hugging Face and PyTorch

supersjgk
6 min read · Dec 11, 2023


Navigating CIFAR-10 Image Classification

Introduction

Ever heard of "Attention Is All You Need"? It’s a big deal. Transformers started out in natural language processing, but they are now everywhere, including images, through Vision Transformers (ViTs), first introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." They aren’t just another flashy trend; they’ve proven to be serious contenders to traditional models such as Convolutional Neural Networks (CNNs).

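Short overview of ViT:

  • Split each image into fixed-size patches (16x16 pixels for the model used here), flatten them, and map each patch to an embedding vector with a learned linear projection (the paper also describes a hybrid variant that takes its patches from a CNN feature map).
  • Prepend a learnable [CLS] token and add learnable position embeddings, since the encoder itself has no notion of patch order.
  • Pass the sequence through a standard Transformer encoder, with a classification head on the [CLS] token’s output.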
A snapshot of the ViT architecture
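The three steps above can be sketched in plain PyTorch. This is a simplified illustration, not Hugging Face’s implementation: the PatchEmbedding class is hypothetical, the encoder has only 2 layers instead of ViT-Base’s 12, and the final LayerNorm before the head is omitted. It uses the common trick of a Conv2d whose kernel size and stride equal the patch size, which is equivalent to flattening each patch and applying one shared linear layer.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """ViT-style input embedding: patchify, project, prepend [CLS], add positions."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2  # (224 // 16) ** 2 = 196
        # kernel_size == stride == patch_size: one shared linear map per non-overlapping patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        x = self.proj(x)                                  # (B, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)                  # (B, 196, 768)
        cls = self.cls_token.expand(x.shape[0], -1, -1)   # (B, 1, 768)
        x = torch.cat([cls, x], dim=1)                    # (B, 197, 768)
        return x + self.pos_embed                         # add learnable position information


embed = PatchEmbedding()
layer = nn.TransformerEncoderLayer(
    d_model=768, nhead=12, dim_feedforward=3072,
    activation="gelu", batch_first=True, norm_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=2)  # ViT-Base stacks 12 such layers
head = nn.Linear(768, 10)                              # e.g. 10 CIFAR-10 classes

images = torch.randn(2, 3, 224, 224)                   # dummy batch of 2 RGB images
tokens = embed(images)
logits = head(encoder(tokens)[:, 0])                   # classify from the [CLS] token
print(tokens.shape, logits.shape)  # torch.Size([2, 197, 768]) torch.Size([2, 10])
```

A 224x224 image becomes a sequence of 197 tokens (196 patches plus [CLS]); from there on, the model is an ordinary Transformer encoder.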

This story isn’t about the nitty-gritty of ViTs; it’s a practical guide to fine-tuning pretrained ViT image classification models with Hugging Face and PyTorch so you can use them for your own tasks. Stick with me, and I’ll help you make these transformers work for whatever you need!

Problem Statement

Our goal is to use a pretrained Vision Transformer for image classification on the CIFAR-10 dataset*. The challenge is a mismatch between the dataset the model was trained on and our target dataset: both the image size and the number of output classes differ. To address this, we resize the inputs and fine-tune the model with a new classification head.

The model we will use is google/vit-base-patch16-224 *. It was pretrained on ImageNet-21k (14 million images, 21,843 classes) and fine-tuned on ImageNet-1k (1 million images, 1,000 classes). It uses a patch size of 16x16 and processes images of size 3x224x224.
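To see where these numbers come from, here is a quick calculation with transformers’ ViTConfig. Its default values describe a ViT-Base model with 16x16 patches at 224x224 resolution, which is assumed here to match this checkpoint; ViTConfig.from_pretrained("google/vit-base-patch16-224") would download the checkpoint’s actual configuration. Building the default config needs no download.

```python
from transformers import ViTConfig

# Default ViTConfig: ViT-Base, 16x16 patches, 224x224 RGB inputs (no download needed).
cfg = ViTConfig()

patches_per_side = cfg.image_size // cfg.patch_size          # 224 // 16 = 14
num_patches = patches_per_side ** 2                          # 14 * 14 = 196
seq_len = num_patches + 1                                    # +1 for the [CLS] token
values_per_patch = cfg.patch_size ** 2 * cfg.num_channels    # 16 * 16 * 3 = 768

print(cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads)  # 768 12 12
print(num_patches, seq_len, values_per_patch)                           # 196 197 768
```

Each flattened 16x16x3 patch holds 768 values, which happens to equal ViT-Base’s hidden size.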

Our objective is to further fine-tune it on CIFAR-10, which has only 10 output classes and images of size 3x32x32. This tutorial can serve as a starting point for fine-tuning any ViT available in the Hugging Face library for a wide range of tasks.

* Any dataset/model can be used with appropriate tweaking.

Setting up the environment

You can use either Jupyter or Google Colab. Install and import the necessary libraries and frameworks.

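!pip install torch torchvision
!pip install transformers datasets
!pip install "transformers[torch]"  # the torch extra also installs accelerate, which Trainer needs
!pip install scikit-learn matplotlib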
# PyTorch
import torch
import torchvision
from torchvision.transforms import Normalize, Resize, ToTensor, Compose
# For displaying images
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
# Loading dataset
from datasets import load_dataset
# Transformers
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
# Matrix operations
import numpy as np
# Evaluation
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

Data Preprocessing

For demonstration purposes, we use only a small subset of the dataset: the first 5,000 training images and the first 1,000 test images. We then split the 5,000 training images into training (90%) and validation (10%) sets.

trainds, testds = load_dataset("cifar10", split=["train[:5000]","test[:1000]"])
splits = trainds.train_test_split(test_size=0.1)
trainds = splits['train']
valds = splits['test']
trainds, valds, testds
# Output
(Dataset({
     features: ['img', 'label'],
     num_rows: 4500
 }),
 Dataset({
     features: ['img', 'label'],
     num_rows: 500
 }),
 Dataset({
     features: ['img', 'label'],
     num_rows: 1000
 }))

Note: Blocks whose first line is # Output are not part of the code; they show the output of the preceding code for reference.

If you’re not familiar with the datasets package, here’s how you can inspect a dataset’s features, its size, and a single item:

trainds.features, trainds.num_rows, trainds[0]
# Output
({'img': Image(decode=True, id=None),
  'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], id=None)},
 4500,
 {'img': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
  'label': 0})

Now let’s map the integer labels to string labels and vice-versa.

itos = dict((k,v) for k,v in enumerate(trainds.features['label'].names))
stoi = dict((v,k) for k,v in enumerate(trainds.features['label'].names))
itos
# Output
{0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}

Now, let’s display an image and corresponding label from the dataset.

index = 0
img, lab = trainds[index]['img'], itos[trainds[index]['label']]
print(lab)
img
Airplane: 3x32x32 image

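Now let’s preprocess the images. We load the ViTImageProcessor that belongs to the checkpoint: it records the preprocessing the model expects (resizing to 224x224, rescaling pixel values to [0, 1], and normalizing with a per-channel mean and standard deviation). It does not split images into 16x16 patches; that happens inside the model’s patch-embedding layer. Here we only read the processor’s settings and rebuild the same steps with torchvision.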

model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)

mu, sigma = processor.image_mean, processor.image_std  # per-channel mean/std used by this checkpoint
size = processor.size
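Instead of rebuilding the preprocessing with torchvision, you can also call the image processor directly; it resizes, rescales, and normalizes, returning a pixel_values tensor the model accepts. This sketch uses ViTImageProcessor’s default settings so that nothing is downloaded; those defaults (mean = std = 0.5, 224x224 output) are assumed to match this checkpoint, so check processor.image_mean, processor.image_std, and processor.size on the real one. The random image is a hypothetical stand-in for a CIFAR-10 sample.

```python
import numpy as np
from PIL import Image
from transformers import ViTImageProcessor

# Default settings: no download. Assumed to match google/vit-base-patch16-224.
proc = ViTImageProcessor()
print(proc.image_mean, proc.image_std, proc.size)

# Dummy 32x32 RGB image standing in for a CIFAR-10 sample.
img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
pixel_values = proc(images=img, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # torch.Size([1, 3, 224, 224]); values lie in [-1, 1]
```

Either route works; a torchvision pipeline makes it easier to add your own augmentations.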

We use a torchvision transforms pipeline; you can add other transforms (for example, data augmentation) to fit the needs of your data.

norm = Normalize(mean=mu, std=sigma)  # (x - mean) / std; with mean = std = 0.5 this maps [0, 1] to [-1, 1]

# resize 3x32x32 to 3x224x224 -> convert to PyTorch tensor -> normalize
_transf = Compose([
    Resize((size['height'], size['width'])),
    ToTensor(),
    norm
])

# apply the transforms to each PIL image in the batch and store the results under the 'pixels' key
def transf(arg):
    arg['pixels'] = [_transf(image.convert('RGB')) for image in arg['img']]
    return arg

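Apply the transformations to each dataset. set_transform runs them on the fly whenever items are accessed, instead of preprocessing everything up front.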

trainds.set_transform(transf)
valds.set_transform(transf)
testds.set_transform(transf)

To view a transformed image, run the following code snippet:

idx = 0
ex = trainds[idx]['pixels']
ex = (ex+1)/2  # undo the [-1, 1] normalization; ToPILImage expects float pixels in [0, 1]
exi = ToPILImage()(ex)
plt.imshow(exi)
plt.show()
Transformed Airplane: 3x224x224

Fine-Tuning Model

We use Hugging Face’s ViTForImageClassification, which takes images as input and outputs class predictions. Let’s first see what the original model’s classifier looks like.

model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
print(model.classifier)
# Output
Linear(in_features=768, out_features=1000, bias=True)

It outputs 1,000 logits (unnormalized class scores), as it should, because the checkpoint was fine-tuned on ImageNet-1k.

We can adapt it to output 10 classes with a few extra arguments: num_labels sets the number of output nodes in the final linear layer; ignore_mismatched_sizes=True skips the pretrained 1,000-output classifier weights, which no longer fit, and randomly initializes a new 10-output classifier instead; and id2label/label2id store the mappings between label indices and label strings in the model config.

model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True, id2label=itos, label2id=stoi)
print(model.classifier)
# Output
Linear(in_features=768, out_features=10, bias=True)
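The new head starts out random, so its predictions only become meaningful after training. The sketch below builds a deliberately tiny, randomly initialized ViT from a hypothetical ViTConfig (no download) with the same num_labels, id2label, and label2id arguments, and shows how its logits turn into probabilities and label names.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification

names = ["airplane", "automobile", "bird", "cat", "deer",
         "dog", "frog", "horse", "ship", "truck"]
itos = dict(enumerate(names))
stoi = {n: i for i, n in enumerate(names)}

# Hypothetical tiny sizes so the example runs quickly and offline.
config = ViTConfig(
    image_size=32, patch_size=8, hidden_size=64, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=128,
    num_labels=10, id2label=itos, label2id=stoi,
)
model = ViTForImageClassification(config).eval()
print(model.classifier)  # Linear(in_features=64, out_features=10, bias=True)

with torch.no_grad():
    logits = model(pixel_values=torch.randn(2, 3, 32, 32)).logits  # shape (2, 10)

probs = logits.softmax(dim=-1)   # probabilities summing to 1 for each image
pred_ids = probs.argmax(dim=-1)
print([model.config.id2label[i] for i in pred_ids.tolist()])  # arbitrary: weights are random
```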

Hugging Face Trainer

Trainer is a high-level API that takes care of the training loop, evaluation, checkpointing, and logging.

Let’s start off with the training arguments, where you define hyperparameters, logging, metrics, etc.

args = TrainingArguments(
    output_dir="test-cifar-10",
    save_strategy="epoch",
    evaluation_strategy="epoch",  # named eval_strategy in newer transformers releases
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir='logs',
    remove_unused_columns=False,  # keep the 'img' column, which the on-the-fly transform needs
)

Now we need a collate function for data loading. It stacks the pixel tensors into one batch tensor and creates a tensor of labels. ViTForImageClassification expects the inputs pixel_values and labels, so do not change these key names.

We also need a function to compute metrics. In our case, we’ll use accuracy. I recommend passing a sample input into these functions and printing the values to better understand them.

def collate_fn(examples):
    # stack per-example tensors into one (batch, 3, 224, 224) tensor
    pixels = torch.stack([example["pixels"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixels, "labels": labels}

def compute_metrics(eval_pred):
    predictions, labels = eval_pred  # logits of shape (N, 10) and true class indices
    predictions = np.argmax(predictions, axis=1)
    return dict(accuracy=accuracy_score(labels, predictions))
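Following the advice above, here are both functions called on small hand-made inputs. The fake examples and logits are hypothetical, and the functions are copies of the ones above so the snippet runs on its own.

```python
import numpy as np
import torch
from sklearn.metrics import accuracy_score

def collate_fn(examples):
    pixels = torch.stack([example["pixels"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixels, "labels": labels}

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return dict(accuracy=accuracy_score(labels, predictions))

# Two fake items shaped like the transformed dataset entries.
examples = [{"pixels": torch.zeros(3, 224, 224), "label": 0},
            {"pixels": torch.ones(3, 224, 224), "label": 9}]
batch = collate_fn(examples)
print(batch["pixel_values"].shape, batch["labels"])  # torch.Size([2, 3, 224, 224]) tensor([0, 9])

# Fake logits for 4 images and 3 classes; argmax per row is [0, 2, 1, 1].
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 0.1, 1.5],
                   [0.0, 3.0, 0.1],
                   [0.5, 0.9, 0.2]])
labels = np.array([0, 2, 1, 0])
print(compute_metrics((logits, labels))["accuracy"])  # 0.75 (3 of 4 correct)
```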

Now pass the model, training arguments, datasets, collate function, metric function, and the image processor we loaded earlier to Trainer:

trainer = Trainer(
    model,
    args,
    train_dataset=trainds,
    eval_dataset=valds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,  # saved alongside checkpoints; newer releases call this argument processing_class
)

Training the model

By default, Trainer fine-tunes all of the model’s weights, the pretrained backbone as well as the new 10-class head; no layers are frozen. Start training by simply calling:

trainer.train()
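If you want to train only the new classification head instead, freeze the backbone before calling trainer.train(). ViTForImageClassification keeps the pretrained encoder in model.vit and the head in model.classifier. A sketch on a tiny, randomly initialized model with hypothetical sizes, so it runs without a download; the freezing loop is the same for the real model.

```python
from transformers import ViTConfig, ViTForImageClassification

# Hypothetical tiny config; the same loop applies to google/vit-base-patch16-224.
config = ViTConfig(image_size=32, patch_size=8, hidden_size=64, num_hidden_layers=1,
                   num_attention_heads=2, intermediate_size=128, num_labels=10)
model = ViTForImageClassification(config)

def count_trainable(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

# By default every parameter is trainable.
print(count_trainable(model) == sum(p.numel() for p in model.parameters()))  # True

# Freeze the backbone; only model.classifier keeps requires_grad=True.
for param in model.vit.parameters():
    param.requires_grad = False

print(count_trainable(model))  # 650 = 64 * 10 weights + 10 biases
```

Frozen parameters receive no gradients, so training leaves them unchanged.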

After training is completed, you can see the logs and an output like this:

# Output
TrainOutput(global_step=675, training_loss=0.22329048227380824, metrics={'train_runtime': 1357.9833, 'train_samples_per_second': 9.941, 'train_steps_per_second': 0.497, 'total_flos': 1.046216869705728e+18, 'train_loss': 0.22329048227380824, 'epoch': 3.0})
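The step count in this log can be checked with a little arithmetic. With 4,500 training rows and per_device_train_batch_size=10, a single device would take 450 steps per epoch, or 1,350 steps over 3 epochs, yet the log reports 675. That implies 20 samples per optimizer step, which the logged throughput confirms; it would be consistent with, for example, two GPUs or gradient_accumulation_steps=2, though the article doesn’t say which. A small sketch of the calculation, approximating how Trainer counts steps (total_steps is a hypothetical helper):

```python
import math

num_train = 4500          # rows in trainds
per_device_batch = 10     # per_device_train_batch_size
epochs = 3                # num_train_epochs

def total_steps(n_devices=1, grad_accum=1):
    """Approximate optimizer steps for the whole run (last partial batch kept)."""
    effective_batch = per_device_batch * n_devices * grad_accum
    return math.ceil(num_train / effective_batch) * epochs

print(total_steps())               # 1350 on a single device
print(total_steps(n_devices=2))    # 675, matching global_step in the log
print(round(9.941 / 0.497))        # 20 samples per step, from samples/s divided by steps/s
```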

Evaluation

outputs = trainer.predict(testds)
print(outputs.metrics)
# Output
{'test_loss': 0.07223748415708542, 'test_accuracy': 0.973, 'test_runtime': 28.5169, 'test_samples_per_second': 35.067, 'test_steps_per_second': 4.383}

Here’s how to compare the predicted and true labels for the first test image:

itos[np.argmax(outputs.predictions[0])], itos[outputs.label_ids[0]]
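outputs.predictions holds raw logits, one row of 10 scores per test image, and outputs.label_ids holds the true class indices. Argmax gives the predicted class; a softmax turns each row into probabilities if you also want a confidence score. A sketch on hypothetical logits for 2 images and 3 classes:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Hypothetical logits standing in for outputs.predictions.
predictions = np.array([[4.0, 1.0, 0.5],
                        [0.2, 0.1, 2.5]])
probs = softmax(predictions)
top1 = probs.argmax(axis=1)          # predicted class index per image
confidence = probs.max(axis=1)       # probability of that class

print(top1)                          # [0 2]
print(np.round(probs.sum(axis=1), 6))  # [1. 1.]
```

With the real outputs, itos[top1[i]] gives the predicted label name for image i.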

To plot the confusion matrix use the following code:

y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

labels = trainds.features['label'].names
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(xticks_rotation=45)
Confusion matrix
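The confusion matrix also gives per-class accuracy: divide each diagonal entry (correct predictions for a class) by its row sum (all test images of that class). A sketch on hypothetical labels for 3 classes:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical stand-ins for y_true / y_pred.
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 2, 2, 0, 2])

cm = confusion_matrix(y_true, y_pred)          # rows: true class, columns: predicted class
per_class_acc = cm.diagonal() / cm.sum(axis=1)  # per-class accuracy (recall)
overall_acc = cm.trace() / cm.sum()

print(cm)
print(per_class_acc)   # class 0: 2/3, class 1: 2/2, class 2: 4/5
print(overall_acc)     # 0.8
```

On the real test set, pairing per_class_acc with labels shows which CIFAR-10 classes are hardest for the model.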

Code Notebook

HERE’S THE CODE

