Image Classification using Huggingface ViT

Kenji Tee
11 min read · Aug 3, 2021


For the longest time, Convolutional Neural Networks (CNNs) have been used to perform image classification. However, with the new, state-of-the-art Vision Transformer (ViT) available in Hugging Face Transformers, solving image classification problems with Transformers has never been easier. Today I will show you, step by step, how to fine-tune a Vision Transformer with PyTorch. In my example I will be using the ASL dataset…

First, install and import a few packages:

```python
! pip install transformers pytorch-lightning --quiet
! sudo apt -qq install git-lfs

import glob
import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image, UnidentifiedImageError
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision.datasets import ImageFolder
from transformers import ViTFeatureExtractor, ViTForImageClassification
```

In this example the dataset comes from Kaggle, but you can use datasets from other sources. Below is a quick guide to getting Kaggle datasets into Google Colab. For a more detailed tutorial, check out this link: https://galhever.medium.com/how-to-import-data-from-kaggle-to-google-colab-8160caa11e2

```python
!pip install -q kaggle
from google.colab import files
files.upload()  # opens a file picker: choose your kaggle.json
```

Upload your kaggle.json file. If you don’t have one yet, you can get it from Kaggle: view your account page on Kaggle by selecting “Account” in the navigation bar, where you will see your User ID, User Name, etc. Scroll a little lower to the API section and click “Create a New API Token”. Once your file is uploaded, copy kaggle.json into ~/.kaggle, because that is where the Kaggle API client expects to find it. The code below does that for you.

```
! mkdir -p ~/.kaggle
! cp kaggle.json ~/.kaggle/
```

Then change the file’s permissions so that only you can read and write it:

```
! chmod 600 ~/.kaggle/kaggle.json
```
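If you prefer to do these two steps from Python, for example in a script outside Colab, here is a minimal sketch that uses only the standard library. `install_kaggle_token` is a hypothetical helper name. It assumes the downloaded token sits in the current directory as kaggle.json.

```python
import os
import shutil
import stat
from pathlib import Path


def install_kaggle_token(src="kaggle.json", kaggle_dir=Path.home() / ".kaggle"):
    """Hypothetical helper mirroring `mkdir -p ~/.kaggle`, `cp kaggle.json ~/.kaggle/`
    and `chmod 600 ~/.kaggle/kaggle.json`."""
    kaggle_dir = Path(kaggle_dir)
    kaggle_dir.mkdir(parents=True, exist_ok=True)  # ~/.kaggle may not exist yet
    dest = kaggle_dir / "kaggle.json"
    shutil.copy(src, dest)
    # Owner read/write only, i.e. the same as chmod 600
    os.chmod(dest, stat.S_IRUSR | stat.S_IWUSR)
    return dest
```

In Colab you could call `install_kaggle_token()` right after `files.upload()`.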

Once that is done, you have access to all of Kaggle’s datasets. To see a list of them, run the code below.

```
! kaggle datasets list
```

From these datasets, I’ll be downloading the ASL Alphabet dataset. (Note that if you are downloading data from a Kaggle competition, you first need to accept the competition rules on Kaggle.)

```
!kaggle datasets download -d grassknoted/asl-alphabet
```

Next, unzip the file.

```
!unzip asl-alphabet.zip
```
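Outside a notebook, `!unzip` is not available, but Python’s built-in zipfile module can do the same job. This is a minimal sketch that assumes the archive name used above. `unzip_dataset` is a hypothetical helper.

```python
import zipfile
from pathlib import Path


def unzip_dataset(zip_path="asl-alphabet.zip", out_dir="."):
    """Extract every member of zip_path into out_dir (like `!unzip asl-alphabet.zip`)
    and return the list of archive member names."""
    out_dir = Path(out_dir)
    with zipfile.ZipFile(zip_path) as zf:
        names = zf.namelist()
        zf.extractall(out_dir)  # creates any intermediate folders as needed
    return names
```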

Before continuing, let me explain how the files are organized. After unzipping, there are two folders: “asl_alphabet_test” and “asl_alphabet_train”. Each contains a second folder with the same name, which is why the paths below repeat the folder name. Inside the train folder are subfolders named after their classes, so my subfolders are named ‘A’, ‘B’, ‘C’, … ‘Z’ (plus ‘del’, ‘nothing’ and ‘space’). Each subfolder contains images of the ASL sign for that particular folder name. Below are a few example images you’ll find in the ‘A’ folder.

Figure 1. Examples of images in folder A.

The test folder, on the other hand, contains only unlabeled photos, so there is no need for subfolders. Each file name starts with the sign it shows, for example O_test.jpg. Below is an example of what the test folder looks like.

Figure 2. Example of test folder.

If you’d like to use your own dataset, make sure its folder structure is similar.

Create a path to your training data:

```python
data_dir = Path("/content/asl_alphabet_train/asl_alphabet_train")
```

We then load the dataset using PyTorch’s ImageFolder class. ImageFolder expects the data to be organized as shown below: one subfolder per class, with that class’s images inside it.

Figure 3. Directory layout expected by the ImageFolder class (from the PyTorch documentation).
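ImageFolder derives the labels from this layout: it sorts the class subfolder names and numbers them from 0. The sketch below reproduces that rule with the standard library, so you can preview the mapping before loading any images. `find_classes` here is written for illustration and only approximates what ImageFolder does internally.

```python
import os
from pathlib import Path


def find_classes(root):
    """Illustrative re-implementation of ImageFolder's labelling rule:
    sorted subfolder names -> 0, 1, 2, ... (plain files are ignored)."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx
```

The sort is case-sensitive, so uppercase folder names such as ‘A’ to ‘Z’ are numbered before lowercase ones.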

Next, we shuffle the dataset indices and split them. We set n_val to 15% of the number of indices, so the remaining 85% are used for training. We then use torch.utils.data.Subset to build the two splits from the shuffled indices.

```python
ds = ImageFolder(data_dir)

# Shuffle the sample indices, then hold out the last 15% for validation
indices = torch.randperm(len(ds)).tolist()
n_val = math.floor(len(indices) * .15)
train_ds = torch.utils.data.Subset(ds, indices[:-n_val])
val_ds = torch.utils.data.Subset(ds, indices[-n_val:])
```
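One detail about slicing with `indices[:-n_val]`: if n_val ever becomes 0, for example with a very small dataset, `indices[:-0]` is the same as `indices[:0]`. That is an empty list, so the training set would be empty. The sketch below performs the same 85/15 split in plain Python with that edge case handled. It uses Python’s random module with a fixed seed instead of torch.randperm, so the actual permutation differs from the one above. `split_indices` is a hypothetical name.

```python
import math
import random


def split_indices(n, val_frac=0.15, seed=42):
    """Shuffle range(n) and split it into (train, val) index lists.

    Slicing with n - n_val instead of -n_val keeps the training split intact
    even when n_val == 0.
    """
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    n_val = math.floor(n * val_frac)
    return indices[: n - n_val], indices[n - n_val:]
```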

To check that your images look fine, run the code below. It plots one example image per class:

```python
plt.figure(figsize=(100, 50))
num_examples_per_class = 1
i = 1
for class_idx, class_name in enumerate(ds.classes):
    folder = Path(ds.root) / class_name
    for image_idx, image_path in enumerate(sorted(folder.glob('*'))):
        if image_path.suffix in ds.extensions:
            image = Image.open(image_path)
            # One row per class, num_examples_per_class columns; i is the 1-based panel index
            plt.subplot(len(ds.classes), num_examples_per_class, i)
            ax = plt.gca()
            ax.set_title(
                class_name,
                size='xx-large',
                pad=5,
                loc='left',
                y=0,
                backgroundcolor='white'
            )
            ax.axis('off')
            plt.imshow(image)
            i += 1
            if image_idx + 1 == num_examples_per_class:
                break
```

The code above first creates a figure of the given size with plt.figure. It then loops over ds.classes and builds a variable called ‘folder’, which is the dataset root plus the class name, for example ‘/content/asl_alphabet_train/asl_alphabet_train/A’. A nested loop goes through the files inside that folder in sorted order. If a file’s suffix is in ds.extensions, which are (‘.jpg’, ‘.jpeg’, ‘.png’, ‘.ppm’, ‘.bmp’, ‘.pgm’, ‘.tif’, ‘.tiff’, ‘.webp’), the image is opened and drawn into its own subplot with Matplotlib. plt.subplot takes three arguments: the number of rows, the number of columns and the (1-based) index of the subplot. plt.gca() returns the current Axes, on which we set the title and hide the axis lines, and plt.imshow() displays the image. The inner loop stops once image_idx + 1 reaches num_examples_per_class.
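If you only want to see which files the loop would pick, without plotting anything, you can pull the selection logic out into a small function. This is a sketch. `first_images_per_class` is a hypothetical helper, and the extension tuple is the one listed above. Unlike the loop above, it counts only files that pass the extension filter.

```python
from pathlib import Path

# Extensions listed above for ds.extensions
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def first_images_per_class(root, classes, k=1, extensions=IMG_EXTENSIONS):
    """Return {class_name: [first k image paths]}, using the same sorted order
    and suffix filter as the plotting loop."""
    picked = {}
    for class_name in classes:
        folder = Path(root) / class_name
        images = [p for p in sorted(folder.glob('*')) if p.suffix in extensions]
        picked[class_name] = images[:k]
    return picked
```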

Next, create the label dictionaries that the model configuration needs:

```python
# Map class names to ids (stored as strings) and back
label2id = {}
id2label = {}
for i, class_name in enumerate(ds.classes):
    label2id[class_name] = str(i)
    id2label[str(i)] = class_name
```
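The same dictionaries can be built with dict comprehensions. The sketch below uses a hypothetical `make_label_maps` helper. It also shows one plausible reason for using string ids: Hugging Face saves the model configuration as a JSON file (config.json), and JSON object keys are always strings.

```python
import json


def make_label_maps(classes):
    """Build label2id / id2label like the loop above, using dict comprehensions."""
    label2id = {name: str(i) for i, name in enumerate(classes)}
    id2label = {str(i): name for i, name in enumerate(classes)}
    return label2id, id2label


# Usage with a few ASL-style class names (ds.classes in the tutorial)
label2id_demo, id2label_demo = make_label_maps(["A", "B", "del"])

# Integer keys come back as strings after a JSON round trip; string keys survive unchanged
print(json.loads(json.dumps({0: "A"})))                         # {'0': 'A'}
print(json.loads(json.dumps(id2label_demo)) == id2label_demo)  # True
```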

For preprocessing, we create a custom image classification collator that turns a list of samples into a batch. Each sample from the dataset is an (image, label) pair. The collator passes the images (x[0] for x in batch) through the feature extractor, which returns them as PyTorch tensors. It then stores the labels (x[1] for x in batch) in encodings[‘labels’] as a torch.long tensor.

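```python
class ImageClassificationCollator:
    """Turns a list of (PIL image, label) pairs into a batch the model accepts."""

    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        # Resize/normalize the images and return them as a tensor under 'pixel_values'
        encodings = self.feature_extractor([x[0] for x in batch],
                                           return_tensors='pt')
        # Class indices as int64, as expected by the classification loss
        encodings['labels'] = torch.tensor([x[1] for x in batch],
                                           dtype=torch.long)
        return encodings
```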
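To see what the collator produces without downloading a model, here is a sketch with a stand-in feature extractor. `fake_feature_extractor` and `ImageClassificationCollatorSketch` are hypothetical. The fake simply records the length of each "image", and the labels are kept as a list instead of a torch.long tensor, so the example runs without PyTorch.

```python
def fake_feature_extractor(images, return_tensors=None):
    """Hypothetical stand-in for ViTFeatureExtractor: 'encodes' each image as its length."""
    return {"pixel_values": [len(img) for img in images]}


class ImageClassificationCollatorSketch:
    """Same logic as ImageClassificationCollator, minus the torch tensor for labels."""

    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        # batch is a list of (image, label) pairs, as yielded by ImageFolder / Subset
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors="pt")
        encodings["labels"] = [x[1] for x in batch]
        return encodings
```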

Using the Hugging Face ViTFeatureExtractor, we load the preprocessing settings of the pretrained ‘google/vit-base-patch16-224-in21k’ model: the image size, and the normalization mean and standard deviation it was trained with. This way our images are prepared the same way before they go through the custom collator. The collator instance is passed to the PyTorch DataLoader as the collate_fn parameter.

The DataLoader builds batches from the dataset and can load data in parallel. Its main arguments are:

- batch_size sets how many samples go into each batch.
- collate_fn customizes how samples are combined into a batch.
- num_workers enables multi-process data loading when set to a positive integer.

By default num_workers is 0, which means the data is loaded inside the main process. Training therefore works sequentially: whenever training needs another batch, the main process has to stop and read it from disk. With worker processes, we can use the fact that the machine has multiple cores. The next batch is already loaded and queued in memory by the time the main process needs it, and this is where the speed-up comes from. The best value for num_workers depends on how many cores your machine has and whether you train on a CPU, GPU or TPU. Setting shuffle=True reshuffles the training data at the start of every epoch.

Finally, we load the pretrained ViTForImageClassification model into the variable model. The checkpoint has no classification head for our classes, so passing num_labels, label2id and id2label creates a new, randomly initialized head of the right size.

```python
# Preprocessing settings (resize, normalization) of the pretrained checkpoint
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
collator = ImageClassificationCollator(feature_extractor)

train_loader = DataLoader(train_ds, batch_size=32,
                          collate_fn=collator, num_workers=2, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, collate_fn=collator,
                        num_workers=2)

# Pretrained ViT backbone with a new classification head sized for our classes
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label)
```
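Roughly, a DataLoader with num_workers=0 does the following in the main process. It picks an order of indices (a fresh random order each epoch when shuffle=True), cuts that order into chunks of batch_size, and passes each chunk of samples to collate_fn. The single-process sketch below illustrates this with the default drop_last=False, so the last batch may be smaller. `iterate_batches` and `num_batches` are hypothetical names; the real DataLoader adds worker processes and many other options.

```python
import math
import random


def iterate_batches(dataset, batch_size, collate_fn, shuffle=False, rng=None):
    """Single-process sketch of DataLoader iteration over one epoch."""
    order = list(range(len(dataset)))
    if shuffle:
        # A new order every time this is called, i.e. every epoch
        (rng or random).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield collate_fn([dataset[i] for i in order[start:start + batch_size]])


def num_batches(n_samples, batch_size):
    # The last, possibly smaller, batch is kept (drop_last=False)
    return math.ceil(n_samples / batch_size)
```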

Now we create our Classifier class to fine-tune the model. The Classifier takes two arguments: the model and the learning rate. It defines the training step, the validation step and the optimizer configuration. We use the Adam optimizer (torch.optim.Adam) in this example; AdamW (torch.optim.AdamW) is a variant that decouples weight decay from the adaptive gradient update. On a side note, after training I realized that my model was overfitting, so I added weight decay. Intuitively, weight decay adds an L2 penalty on the weights to the cost, which pushes the model towards smaller weights.

```python
class Classifier(pl.LightningModule):
    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        # Recent torchmetrics versions require the task type and number of classes
        self.val_acc = Accuracy(task='multiclass', num_classes=model.config.num_labels)

    def forward(self, **inputs):
        # With 'labels' in inputs, the Hugging Face model also returns the loss
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        # Adam with weight decay (an L2 penalty) to reduce overfitting
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams.lr, weight_decay=0.00025)
```
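Here is a small numeric illustration of what the L2 penalty does. It uses plain gradient descent rather than Adam to keep things simple, and `sgd_step_with_l2` is a hypothetical helper. With a penalty of (weight_decay / 2) · ‖w‖², the penalty’s gradient is weight_decay · w, which is added to the data gradient. torch.optim.Adam’s weight_decay argument adds the same term to the gradient before its adaptive scaling.

```python
def sgd_step_with_l2(weights, grads, lr, weight_decay):
    """One plain gradient-descent step on loss + (weight_decay / 2) * sum(w ** 2)."""
    # Gradient of the penalty is weight_decay * w, added to the data gradient g
    return [w - lr * (g + weight_decay * w) for w, g in zip(weights, grads)]


# With a zero data gradient, each step multiplies every weight by (1 - lr * weight_decay)
w = [1.0, -2.0]
for _ in range(100):
    w = sgd_step_with_l2(w, [0.0, 0.0], lr=0.1, weight_decay=0.5)
print(w)  # both weights have shrunk towards zero (factor 0.95 per step)
```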

Before training, we set the seed of the pseudo-random number generators to 42 with pl.seed_everything() (feel free to use any number), so that reruns give consistent results. We pass two arguments to our Classifier: the model created earlier and the learning rate, which is 2e-5 in this example. The learning rate is one of the most important hyperparameters for getting training to converge. A learning rate that is too large can make training unstable, while one that is too small can make progress so slow that the model fails to learn. In pl.Trainer we choose how many GPUs to use: this is the gpus argument in older PyTorch Lightning releases, and accelerator="gpu" with devices in newer ones. On Google Colab only one GPU is available. We can also choose the arithmetic precision: 16-bit, 32-bit or 64-bit. Higher precision does not necessarily give higher accuracy, and 16-bit precision can reach comparable accuracy while using less memory, so I use 16-bit precision in this example. Finally, we pass our classifier, train_loader and val_loader to trainer.fit, which trains for max_epochs=3 epochs.

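```python
pl.seed_everything(42)
classifier = Classifier(model, lr=2e-5)
# One GPU with 16-bit precision (the older `gpus=1` argument was removed in PyTorch Lightning 2.0)
trainer = pl.Trainer(accelerator='gpu', devices=1, precision=16, max_epochs=3)
trainer.fit(classifier, train_loader, val_loader)
```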

The training process should look as below:

Figure 3. Training model
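To build intuition for the learning-rate trade-off described above, here is a toy example that is unrelated to ViT itself. It runs plain gradient descent on f(x) = x², whose gradient is 2x, so each step multiplies x by (1 - 2 · lr). `gradient_descent` is a hypothetical helper.

```python
def gradient_descent(lr, x0=1.0, steps=50):
    """Minimise f(x) = x**2 (gradient 2x) from x0 and return the final x."""
    x = x0
    for _ in range(steps):
        x -= lr * 2 * x
    return x


for lr in (1.1, 0.1, 1e-4):
    print(lr, gradient_descent(lr))
```

With lr = 1.1 the factor is -1.2, so |x| grows every step and training diverges. With lr = 1e-4, x barely moves in 50 steps. With lr = 0.1, x shrinks by a factor of 0.8 per step and quickly approaches the minimum.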

Once training is done, we create a prediction function that takes an image path and opens the image with Image.open. The feature_extractor resizes and normalizes the image and returns PyTorch tensors, which we store in a variable called ‘encoding’. Its ‘pixel_values’ key holds the processed pixel values, which we feed into the model to get a prediction. Calling outputs.logits.softmax(1) turns the logits into probabilities that sum to 1 for each image, and argmax(1) returns the index of the class with the highest probability. Since the result is a tensor, we convert it into a Python list, then convert the index into a string, because the keys of id2label are strings.

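```python
def prediction(img_path):
    im = Image.open(img_path).convert('RGB')  # make sure the image has 3 channels
    encoding = feature_extractor(images=im, return_tensors="pt")
    pixel_values = encoding['pixel_values'].to(model.device)
    model.eval()  # disable dropout for inference
    with torch.no_grad():
        outputs = model(pixel_values)
    # Class probabilities -> index of the most likely class
    result = outputs.logits.softmax(1).argmax(1)
    new_result = result.tolist()
    return id2label[str(new_result[0])]
```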
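Here is a plain-Python sketch of what softmax(1).argmax(1) does to one row of logits. Softmax preserves the ordering of the logits, so the argmax is the same with or without it. The probabilities are only useful if you also want a confidence score. `predict_label` is a hypothetical helper that returns both.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def predict_label(logits, id2label):
    """Return (label, probability) for the most likely class; id2label has string keys."""
    probs = softmax(logits)
    best = argmax(probs)  # same index as argmax(logits)
    return id2label[str(best)], probs[best]
```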

Before an image can be shown, it needs to be processed, so we create a process_image function. As before, we first open the image with Image.open and then resize it. If the width is greater than the height, we shrink the image with thumbnail so that its height becomes 256 pixels; otherwise its width becomes 256 pixels. The 5000 is just a large bound, thumbnail keeps the aspect ratio, and it never enlarges an image. Then we center-crop a 224×224 region using the computed margins. Next, we normalize the image. We convert it to a NumPy array and divide by 255, the largest pixel value, so that every value lies between 0 and 1. We then apply the standard PyTorch normalization with the ImageNet mean and standard deviation: subtract the mean from each color channel, then divide by the standard deviation. PyTorch expects the color channel to be the first dimension (C×H×W), but it is currently the third (H×W×C), so we use ndarray.transpose to reorder the image to C×H×W.

```python
def process_image(image_path):
    pil_image = Image.open(image_path).convert('RGB')
    # Shrink so the shorter side is 256 px (thumbnail keeps the aspect ratio, never enlarges)
    if pil_image.size[0] > pil_image.size[1]:
        pil_image.thumbnail((5000, 256))
    else:
        pil_image.thumbnail((256, 5000))

    # Center-crop 224x224; PIL's crop box is (left, upper, right, lower)
    left_margin = (pil_image.width - 224) / 2
    top_margin = (pil_image.height - 224) / 2
    right_margin = left_margin + 224
    bottom_margin = top_margin + 224
    pil_image = pil_image.crop((left_margin, top_margin,
                                right_margin, bottom_margin))

    # Scale to [0, 1], then normalize each channel with the ImageNet mean/std
    np_image = np.array(pil_image) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    np_image = (np_image - mean) / std

    # H x W x C -> C x H x W
    np_image = np_image.transpose((2, 0, 1))
    return np_image
```
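You can check the resize-and-crop arithmetic in process_image on its own. The sketch below computes the size after scaling the shorter side to 256 pixels, and the 224×224 center-crop box. PIL’s thumbnail() only ever shrinks an image, so images already smaller than 256 pixels keep their size; `shorter_side_size` (a hypothetical helper) shows what a true resize would give, up to PIL’s own rounding.

```python
def shorter_side_size(width, height, target=256):
    """(width, height) after scaling so the shorter side equals target, keeping the aspect ratio."""
    if width > height:
        return round(width * target / height), target
    return target, round(height * target / width)


def center_crop_box(width, height, size=224):
    """(left, upper, right, lower) box for PIL's Image.crop, centred on the image."""
    left = (width - size) / 2
    upper = (height - size) / 2
    return (left, upper, left + size, upper + size)


print(shorter_side_size(640, 480))  # (341, 256)
print(center_crop_box(256, 256))    # (16.0, 16.0, 240.0, 240.0)
```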

To show the image, we create another function called imshow, which takes the image, an axes object ax and a title as arguments. If no axes object is provided, we create our own figure and axes with plt.subplots(). Matplotlib expects the color channel in the third dimension, so we transpose the image back and undo the normalization. If a title was given, we set it with ax.set_title(). We also clip the pixel values to the range 0 to 1; otherwise the image would look like noise when displayed.

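```python
def imshow(image, ax=None, title=None):
    if ax is None:
        fig, ax = plt.subplots()
    # C x H x W -> H x W x C, which Matplotlib expects
    image = image.transpose((1, 2, 0))
    # Undo the normalization applied in process_image
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    if title is not None:
        ax.set_title(title)
    # Values outside [0, 1] would otherwise make the image look like noise
    image = np.clip(image, 0, 1)
    ax.imshow(image)
    return ax
```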
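Apart from the final clipping, process_image and imshow apply inverse transformations. The compact NumPy sketch below checks that round trip: undoing the normalization and the transpose recovers the original pixel values scaled to [0, 1]. The function names are hypothetical.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])  # ImageNet channel means
STD = np.array([0.229, 0.224, 0.225])   # ImageNet channel standard deviations


def normalize_hwc_to_chw(img_uint8):
    """uint8 H x W x 3 image -> normalized float 3 x H x W array (as in process_image)."""
    x = img_uint8 / 255
    x = (x - MEAN) / STD
    return x.transpose((2, 0, 1))


def denormalize_chw_to_hwc(x):
    """Undo normalize_hwc_to_chw for display (as in imshow)."""
    img = x.transpose((1, 2, 0))
    return np.clip(STD * img + MEAN, 0, 1)
```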

Finally, we display the image together with its title and the predicted sign using a display_image() function. The title is the image file name, which we find by locating the last ‘/’ in the image path and taking everything after it. For example, take the image path image_path1 = ‘/content/asl_alphabet_test/asl_alphabet_test/O_test.jpg’. Here rfind(‘/’) returns the position of the last ‘/’, and adding 1 makes the slice start at the next character, giving ‘O_test.jpg’. We then predict the sign, set it as the x-axis label with set_xlabel(), and finally call imshow() to display everything.

```python
def display_image(image_dir):
    plt.figure(figsize=(6, 10))
    plot_1 = plt.subplot(2, 1, 1)
    image = process_image(image_dir)
    # File name after the last '/', e.g. 'O_test.jpg'
    asl_sign = image_dir[image_dir.rfind('/') + 1:]

    pred = prediction(image_dir)
    plot_1.set_xlabel("The predicted sign: " + pred)
    imshow(image, plot_1, title=asl_sign)
```
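The rfind trick gives the same result as os.path.basename or pathlib’s Path.name. Since the test files are named like O_test.jpg, the true label can also be read from the file name. Here is a sketch with hypothetical helper names:

```python
import os
from pathlib import Path


def title_from_path(image_dir):
    """File name after the last '/', exactly as display_image computes it."""
    return image_dir[image_dir.rfind('/') + 1:]


def true_label_from_filename(image_dir):
    """Hypothetical helper: text before the first '_' in names like 'O_test.jpg'."""
    return Path(image_dir).name.split('_')[0]


p = '/content/asl_alphabet_test/asl_alphabet_test/O_test.jpg'
print(title_from_path(p), os.path.basename(p), true_label_from_filename(p))  # O_test.jpg O_test.jpg O
```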

Here is an example for a single image:

```python
image_path1 = '/content/asl_alphabet_test/asl_alphabet_test/O_test.jpg'
display_image(image_path1)
```
Figure 4. Sample Testing

Next, let’s run predictions on our asl_alphabet_test data:

```python
test_data_path = '/content/asl_alphabet_test/asl_alphabet_test'
```

Since we want all the images in that folder, we use glob.glob, which returns a list of paths matching a pathname pattern.

```python
images_path = glob.glob(test_data_path + '/*.jpg')
```

Then we loop over each image:

```python
for i in images_path:
    display_image(i)
```
Figure 5. More sample testing
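Here are two optional refinements, sketched under some assumptions. First, on Linux (including Colab) glob matching is case-sensitive, so '*.jpg' would skip a file ending in .JPG. Second, the test file names encode the true sign, so you can turn the visual check into a single accuracy number. `list_test_images` and `accuracy_from_filenames` are hypothetical helpers, and predict can be the prediction() function defined earlier.

```python
from pathlib import Path


def list_test_images(folder, extensions=('.jpg', '.jpeg', '.png')):
    """Sorted image paths in folder, matching extensions case-insensitively."""
    return sorted(p for p in Path(folder).iterdir()
                  if p.is_file() and p.suffix.lower() in extensions)


def accuracy_from_filenames(paths, predict):
    """Fraction of files whose prediction matches the label in the file name
    ('O_test.jpg' -> 'O'); predict is any callable taking a path string."""
    paths = list(paths)
    correct = sum(predict(str(p)) == p.name.split('_')[0] for p in paths)
    return correct / len(paths) if paths else 0.0
```

For example, `accuracy_from_filenames(list_test_images(test_data_path), prediction)` would score the whole test folder.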

This tutorial was created with the help of Nathan Raw and Julien Chaumond’s HuggingPics tutorial on GitHub, and of the PyTorch, Hugging Face and PyTorch Lightning libraries. I would greatly appreciate it if you could let me know about any mistakes or inaccurate information, since this is my first time writing on Medium as well as my first time trying out ViT. My working Python code can be accessed here:

https://github.com/kenjitee/KenjiTee/blob/master/ViT_Tutorial.ipynb
