Tutorial: Using Vision Transformer (ViT) to Create a Pokémon Classifier

Jeff
11 min read · Dec 30, 2021


Introduction

Demo of the model deployed on the Ainize Platform. Disclaimer: All images used here were collected from publicly available sources and for demoing purposes only.

This tutorial aims to give a comprehensive walkthrough on training a Vision Transformer (ViT) model for image classification tasks. We’ll do this by first creating a new dataset of Pokémon images and then using it to fine-tune a pre-trained version of ViT made available by Hugging Face. Once complete, our model will take in ANY Pokémon image and, hopefully, correctly predict its name.

About the Demo

Before getting started, I thought I’d briefly introduce the demo in case you wanted to test out the model yourself. The demo uses a combination of FastAPI and Svelte. It allows you to easily upload an image, drag ‘n drop, or simply paste the image URL to call the model.

Once the FastAPI and Svelte code was written, I deployed it using Ainize. If you’re unfamiliar with Ainize, it allows you to deploy dockerized models for free and create a publicly callable API endpoint. All that is required is a GitHub repository (link to mine) containing a Dockerfile that wraps your model. I have a tutorial that goes over this process more in-depth if you want to learn more. You may also refer to the official Ainize docs.

What you’ll learn in this tutorial:

  1. Data Collection: I’ll discuss how I collected the Pokémon data for this project and link to the full source code. Feel free to skip if you’re only interested in ViT finetuning.
  2. Data Augmentation: We’ll apply some simple augmentation techniques to expand the number of data points.
  3. About ViT: I’ll give an overview of Vision Transformer.
  4. Finetuning ViT: We’ll finetune a pretrained ViT model on the collected Pokémon dataset using PyTorch Lightning.

Data Collection

Any neural network is only as good as the data it is trained on. For that reason, it’s essential to have data that is both plentiful and varied so that the model can learn the correct patterns and generalize to the images it encounters in the real world.

For this project, I’ll primarily be using data collected from two different sources: downloading official sprites made available by Veekun and scraping Brave Search for additional data that offers more variety.

Veekun Images

Veekun offers an easy-to-download collection of official sprites created by Nintendo. Most of these images are in-game sprites from the various Pokémon games. While this serves as a great first choice due to how easy it is to download, it does have some drawbacks. Namely, some of the folders contain duplicates, and the images lack variety (i.e., they are similar in art direction and contain no background).

To address the first issue, I looked through most of the files and discarded any that had notable duplication (it seemed to occur mostly for games in the same generation, as you’d expect). I attempt to address the variability issue in the next section. The directories I used from Veekun can be downloaded here if you’re interested.

Brave Search Images

To acquire more varied and realistic images, I scraped Brave Search. If you’re unfamiliar with Brave Search, it’s a new search engine similar to Google with a greater emphasis on privacy. That said, any search engine should work well for this kind of data collection.

For each Pokémon, I made 3 different search queries:

  1. Pokemon {pokemon_name} art
  2. Pokemon {pokemon_name} card
  3. Pokemon {pokemon_name} wallpaper

The reason for these search queries specifically is that they often led to a good variety in images (orientation, art style, etc.) and were less prone to contain the wrong Pokémon than other queries I tried. Additionally, I did a few quality checks to ensure that the data was of decent quality. Namely:

  1. Checked to make sure that the Pokémon’s name was in the image title. Some images in the search results contained a completely different Pokémon (they commonly had the different evolutions of that Pokémon). This is a heuristic to help ensure the Pokémon’s label matched the image.
  2. Discarded duplicated image URLs to help ensure each image is unique.
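To make the queries and the two quality checks concrete, here is a minimal sketch in plain Python. It is not the scraper itself: the `title` and `url` keys and the function names `build_queries` and `filter_results` are assumptions made for illustration, and a real search response will be shaped differently.

```python
from typing import Dict, Iterable, List


def build_queries(pokemon_name: str) -> List[str]:
    """Return the three search queries used for one Pokémon."""
    return [f"Pokemon {pokemon_name} {suffix}" for suffix in ("art", "card", "wallpaper")]


def filter_results(pokemon_name: str, results: Iterable[Dict[str, str]]) -> List[Dict[str, str]]:
    """Keep results whose title mentions the Pokémon and whose URL has not been seen yet."""
    name = pokemon_name.lower()
    seen_urls = set()
    kept = []
    for result in results:
        title = result.get("title", "").lower()
        url = result.get("url", "")
        # Check 1: the Pokémon's name must appear in the image title.
        if name not in title:
            continue
        # Check 2: skip empty or duplicated image URLs.
        if not url or url in seen_urls:
            continue
        seen_urls.add(url)
        kept.append(result)
    return kept


# Hypothetical search results, only to show the behaviour of the filter.
results = [
    {"title": "Pikachu fan art", "url": "https://example.com/a.png"},
    {"title": "Raichu wallpaper", "url": "https://example.com/b.png"},
    {"title": "Cute PIKACHU card", "url": "https://example.com/a.png"},
    {"title": "Pikachu wallpaper HD", "url": "https://example.com/c.png"},
]
print(build_queries("Pikachu"))
print([r["url"] for r in filter_results("Pikachu", results)])
```

The title check is only a heuristic: it drops the Raichu result here, but it cannot catch an image whose title names the right Pokémon while showing the wrong one.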

The complete source code for data collection can be viewed here. In total, I was able to acquire over 55,000 Pokémon images using these sources.

Data Augmentation

Data augmentation has traditionally served an important role in image classification models. The goal of augmentation is to expand the number of images in the dataset and give it more variability.

There are many commonly used data augmentation techniques. For this tutorial, I’ll be experimenting with six: mirroring, rotating, cropping, quantization, Gaussian blur, and Gaussian noise. Each is described briefly below.

1. Mirroring

Mirroring flips the image horizontally. It’s useful for data augmentation as it teaches the model that the object in an image can be in a different orientation and still be the same object.

2. Rotating

Rotating is very similar to mirroring in concept. It rotates the image by a user-specified number of degrees (random in this case), allowing the model to learn that the object can be in several different orientations. It also has the additional effect of slightly changing the background of the image, which might be useful for the model.

3. Cropping

Cropping randomly samples a selection of the image and then resizes it back to the resolution of the original image. This often gives a zoom effect.
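The three geometric augmentations above can be sketched with Pillow as follows. This is an illustration under assumptions rather than the project's augmentation code: the maximum rotation angle, the crop scale, and the function names are all made up for the example.

```python
import random

from PIL import Image, ImageOps


def mirror(img: Image.Image) -> Image.Image:
    """Flip the image horizontally (left becomes right)."""
    return ImageOps.mirror(img)


def rotate(img: Image.Image, rng: random.Random, max_degrees: float = 30.0) -> Image.Image:
    """Rotate by a random angle; the output keeps the original size."""
    angle = rng.uniform(-max_degrees, max_degrees)
    return img.rotate(angle)


def random_crop(img: Image.Image, rng: random.Random, scale: float = 0.8) -> Image.Image:
    """Crop a random window and resize it back to the original resolution."""
    width, height = img.size
    crop_w, crop_h = int(width * scale), int(height * scale)
    left = rng.randint(0, width - crop_w)
    top = rng.randint(0, height - crop_h)
    cropped = img.crop((left, top, left + crop_w, top + crop_h))
    return cropped.resize((width, height))


rng = random.Random(0)
image = Image.new("RGB", (64, 48), "white")  # stand-in for a real Pokémon image
augmented = random_crop(rotate(mirror(image), rng), rng)
print(augmented.size)  # (64, 48)
```

Note that `Image.rotate` fills the corners it exposes with black by default, which is one way rotation ends up altering the background.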

4. Quantization

A standard image contains 3 color channels: red, green, and blue (commonly called “RGB”). Each channel is 8 bits, ranging from 0 (the channel is black or off) to 255 (the channel is at full intensity). Quantization, in image processing, is a compression method that reduces the number of bits available to each channel (e.g., if you reduced each channel to 1 bit, it’d only be able to output 1 or 0).

For data augmentation, I’m using it to slightly alter the color scheme of the image.
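As a rough sketch of the idea, the NumPy function below keeps only the most significant bits of each 8-bit channel. Keeping the result in the 0 to 255 range (instead of rescaling to 0 and 1) is a choice made for this example so the image can still be displayed normally; the function name `quantize` is also an assumption.

```python
import numpy as np


def quantize(image: np.ndarray, bits: int) -> np.ndarray:
    """Reduce each 8-bit channel to `bits` bits by zeroing the low-order bits."""
    if not 1 <= bits <= 8:
        raise ValueError("bits must be between 1 and 8")
    shift = 8 - bits
    # Shifting right drops the low bits; shifting left restores the scale.
    return (image >> shift) << shift


pixel = np.array([[[200, 130, 37]]], dtype=np.uint8)  # a single RGB pixel
print(quantize(pixel, bits=2))  # [[[192 128   0]]]
```

With 2 bits per channel, every value snaps to one of four levels (0, 64, 128, 192), which shifts the colors slightly without changing the shape of the image.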

5. Gaussian Blur

Gaussian blur smooths the image by convolving it with a Gaussian function. This mostly just means that each pixel is replaced by a weighted average of its surrounding neighborhood of pixels, which gives a blur effect.

6. Gaussian Noise

Adds noise to the image using a Gaussian distribution.
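Both Gaussian techniques fit in a few lines of NumPy. The sketch below is illustrative only: the default `sigma` values, the 3-sigma kernel radius, and the edge padding are assumptions chosen for the example, and a library such as Pillow or OpenCV would normally do the blur for you.

```python
import numpy as np


def gaussian_kernel(sigma: float, radius: int) -> np.ndarray:
    """1-D Gaussian weights that sum to 1."""
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-(offsets ** 2) / (2 * sigma ** 2))
    return kernel / kernel.sum()


def gaussian_blur(image: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """Blur an (H, W, C) uint8 image with a separable Gaussian filter."""
    radius = max(1, int(3 * sigma))
    kernel = gaussian_kernel(sigma, radius)
    # Repeat edge pixels so the output keeps the original height and width.
    padded = np.pad(image.astype(np.float64),
                    ((radius, radius), (radius, radius), (0, 0)), mode="edge")
    # A 2-D Gaussian is separable: filter along the height, then the width.
    blurred = np.apply_along_axis(np.convolve, 0, padded, kernel, mode="valid")
    blurred = np.apply_along_axis(np.convolve, 1, blurred, kernel, mode="valid")
    return np.clip(np.rint(blurred), 0, 255).astype(np.uint8)


def add_gaussian_noise(image: np.ndarray, sigma: float = 10.0, rng=None) -> np.ndarray:
    """Add zero-mean Gaussian noise to every pixel and clip back to 0..255."""
    if rng is None:
        rng = np.random.default_rng()
    noisy = image.astype(np.float64) + rng.normal(0.0, sigma, size=image.shape)
    return np.clip(np.rint(noisy), 0, 255).astype(np.uint8)


spike = np.zeros((9, 9, 3), dtype=np.uint8)
spike[4, 4] = 255  # one bright pixel on a black background
print(gaussian_blur(spike)[4, 3:6, 0])  # the brightness spreads to the neighbors
print(add_gaussian_noise(spike, rng=np.random.default_rng(0)).max())
```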

Applying Data Augmentation

To apply the data augmentation techniques described above, I loop through each Pokémon and randomly select an image for augmentation. I then apply one or more of the methods described above, chosen at random. I repeat this until each Pokémon has a total of 150 images. This number was selected somewhat arbitrarily. I tried expanding to 200 images per Pokémon, but training became difficult due to memory constraints on my desktop. Still, this increased the number of images in the training dataset to over 137k!
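The loop described above can be sketched in plain Python. Treat this as one possible reading of the procedure, not the actual script: it assumes that new images are always derived from the original (non-augmented) ones and that a Pokémon already at or above the target is left untouched, and the name `augment_to_target` is hypothetical.

```python
import random
from typing import Callable, Dict, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def augment_to_target(
    images_by_pokemon: Dict[str, Sequence[T]],
    augmentations: Sequence[Callable[[T], T]],
    target: int = 150,
    rng: Optional[random.Random] = None,
) -> Dict[str, List[T]]:
    """Add augmented copies until every Pokémon has at least `target` images."""
    if rng is None:
        rng = random.Random()
    augmented: Dict[str, List[T]] = {}
    for name, originals in images_by_pokemon.items():
        images = list(originals)
        while originals and len(images) < target:
            image = rng.choice(originals)             # pick a random source image
            count = rng.randint(1, len(augmentations))
            for augment in rng.sample(list(augmentations), count):
                image = augment(image)                # apply 1 or more random methods
            images.append(image)
        augmented[name] = images
    return augmented


# Strings stand in for images so the control flow is easy to follow.
data = {"pikachu": ["p1", "p2"], "eevee": ["e1"]}
methods = [lambda s: s + "+mirror", lambda s: s + "+blur"]
result = augment_to_target(data, methods, target=5, rng=random.Random(0))
print({name: len(images) for name, images in result.items()})  # {'pikachu': 5, 'eevee': 5}
```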

You can check out the complete source code for augmentation here. If others are interested, I may also write a subsequent blog post comparing more augmentation techniques and see which leads to the best results.

Vision Transformer (ViT)

Source: Google AI Blog

Introduction to ViT

Vision Transformer is a relatively new type of image classification model. It replaces the traditional convolutional neural network (CNN) in vision with a transformer-based architecture. If you’re unfamiliar with Transformers, they have dominated the natural language processing (NLP) landscape for the last couple of years, giving rise to powerful models such as GPT-3, BERT, and T5. ViT is the first application of transformers to vision that surpassed the state of the art (SOTA) on several image classification benchmarks (albeit slightly). It marks the first real sign that transformer-based architectures may come to dominate vision just as they have NLP.

Architecture

At its core, ViT is nothing more than the encoder network of a Transformer. However, it makes a few modifications in preprocessing to make it suitable for computer vision. Let’s go over these.

  1. Patching: ViT begins by splitting an image into a sequence of flattened patches. While there are a few different ways to do this, the pre-trained model I’m using first divides the image into non-overlapping patches of 16x16 pixels. When considering each color channel (RGB), each patch is a matrix with a shape of [16, 16, 3]. This is then flattened to form a vector of size 768 (16 x 16 x 3).
  2. Trainable Linear Projection: A single dense (linear) layer with no activation function. It’s applied to each flattened patch to map it to the desired embedding size.
  3. Positional embeddings: Positional embeddings in ViT are learnable parameters added to the patch embeddings to tell the model where in the original image each patch is located.
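The three preprocessing steps above can be sketched with NumPy to make the shapes concrete. The weights here are random placeholders rather than trained parameters, the 224x224 input size and the embedding size of 768 are assumed values matching a ViT-Base style model, and the extra classification token that the real model prepends is left out for brevity.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int = 16) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    height, width, channels = image.shape
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by the patch size")
    rows, cols = height // patch_size, width // patch_size
    patches = image.reshape(rows, patch_size, cols, patch_size, channels)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (rows, cols, 16, 16, C)
    return patches.reshape(rows * cols, patch_size * patch_size * channels)


rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))               # stand-in for a normalized image

# 1. Patching: 14 x 14 = 196 patches, each flattened to 16 * 16 * 3 = 768 values.
patches = patchify(image)

# 2. Linear projection: one weight matrix shared by every patch, no activation.
embed_dim = 768
projection = rng.normal(0.0, 0.02, size=(patches.shape[1], embed_dim))
bias = np.zeros(embed_dim)
patch_embeddings = patches @ projection + bias

# 3. Positional embeddings: one learnable vector per patch position, added on.
position_embeddings = rng.normal(0.0, 0.02, size=(patches.shape[0], embed_dim))
tokens = patch_embeddings + position_embeddings

print(patches.shape, tokens.shape)  # (196, 768) (196, 768)
```

The resulting `tokens` array plays the same role as the token embeddings in a text Transformer: a sequence of 196 vectors that the encoder layers process.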

Once position embeddings are added, ViT works exactly like any other encoder-based Transformer such as BERT. The only major difference is that it uses the patch embeddings described above rather than the traditional token embeddings.

If you’re unfamiliar with transformers, I’d highly recommend looking over this blog post to learn more about the architecture. In essence, though, it’s mostly a series of alternating layers of multi-headed self-attention and dense MLPs.

Why does it work?

To understand the model more in-depth, let’s first consider how CNNs work. CNNs were designed solely for images and, as such, come with inductive biases (i.e., assumptions made by the architecture of the model to improve performance) that make them suitable for computer vision tasks. Some of these inductive biases include:

  1. Locality: CNNs operate under the assumption that pixels close to one another are more strongly related than those further away. This is why they use the convolution operation, a local linear operator that is very good at extracting local features such as edges and corners.
  2. Translation Equivariance: Translation equivariance means that if you shift the input to a convolutional layer by a specific amount, you’ll also shift the output by the same amount. As a simplistic example, if the input [0, 2, 4, 1, 0] produces an output of [0, 1, 0], then the shifted input [0, 0, 2, 4, 1] would lead to an output of [0, 0, 1]. In CNNs, this is achieved through weight sharing: as a filter slides across an image, its weights do not change, so any edge or corner the filter detects in one part of the image will be detected in other parts as well.
  3. Translation Invariance: Translation invariance is similar to the translation equivariance described above. The difference is that if you apply a small shift to the pixels, the output remains the same. From my understanding, this is a result of the pooling layers within CNNs. Max pooling, for example, takes the maximum value within a neighborhood of pixels, so if you move the pixels around slightly within that neighborhood, you end up with the same result.
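The equivariance and invariance examples above can be reproduced in a few lines of NumPy. The filter weights [-1, 2, -1] followed by a "greater than zero" step are an assumption picked so that the numbers match the example in the list; any filter with shared weights would show the same shifting behavior.

```python
import numpy as np


def detect_peak(signal) -> np.ndarray:
    """Slide one shared filter over the signal and flag positive responses."""
    kernel = np.array([-1, 2, -1])  # the same weights are used at every position
    response = np.convolve(signal, kernel, mode="valid")
    return (response > 0).astype(int)


def max_pool(signal, size: int = 2) -> np.ndarray:
    """Take the maximum over consecutive, non-overlapping windows."""
    return np.asarray(signal).reshape(-1, size).max(axis=1)


# Equivariance: shifting the input by one step shifts the output by one step.
print(detect_peak([0, 2, 4, 1, 0]))  # [0 1 0]
print(detect_peak([0, 0, 2, 4, 1]))  # [0 0 1]

# Invariance: moving a value within its pooling window leaves the output unchanged.
print(max_pool([0, 5, 0, 0]))  # [5 0]
print(max_pool([5, 0, 0, 0]))  # [5 0]
```

The invariance only holds for small shifts: once the 5 moves into the second pooling window, the pooled output changes too.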

ViT is a much more general architecture with a far more minimal set of inductive biases. This is both a pro and a con. The downside is that ViT has to learn properties of images that CNNs have built in. In practice, this means that you’re required to pre-train ViT on far more data than a CNN. The major upside is that it can learn more general patterns that a CNN cannot. For instance, a CNN is restricted to looking at local features during the early layers of the network and only observes more global features as you stack multiple convolutional layers. ViT has no such limitation and can extract global features from the start.

Fine-tuning ViT

Intro to fine-tuning

Fine-tuning is an increasingly common practice in deep learning. It involves loading in the weights of a pre-trained model and then training it on a new task. This is widely used because it’s very challenging to train large models from scratch due to the compute and data requirements. Fortunately, many who train these models using large servers release the weights for others to use. Using these weights as an initialization point often gives good results because the model has learned some fundamental world knowledge during its pretraining, which it can carry over when being fine-tuned for downstream tasks.

For this model specifically, we’ll load in a pretrained model publicly available on Hugging Face. This model was pretrained on ImageNet-21K (14 million images total) and later fine-tuned on the base ImageNet (1 million images). While 15 million images is seemingly a substantial amount, it is actually far less than what the best models in the original ViT paper used: they were pretrained on JFT-300M, an unreleased dataset created by Google containing over 300 million images. Though, as you’ll see in the results section, the ImageNet pretraining gives very good results for this task.

Using PyTorch Lightning to finetune ViT

To finetune ViT for Pokémon classification, we must load in the pre-trained model and change the output layer to the total number of Pokémon labels. Additionally, we’ll need to update the config file so the model can associate the output with an ID and the Pokémon’s name.
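As a hedged sketch of those two changes with the Hugging Face `transformers` library: the checkpoint name `google/vit-base-patch16-224` is an assumption (it matches the ImageNet-21K then ImageNet description, but the project may have used a different one), and the helper names are made up for this example. Passing `ignore_mismatched_sizes=True` lets the 1,000-class ImageNet output layer be replaced by a freshly initialized one sized to the Pokémon labels.

```python
from typing import Dict, List, Tuple


def build_label_maps(pokemon_names: List[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Assign a stable integer ID to every Pokémon name."""
    names = sorted(set(pokemon_names))
    id2label = {index: name for index, name in enumerate(names)}
    label2id = {name: index for index, name in id2label.items()}
    return id2label, label2id


def load_vit_for_pokemon(pokemon_names: List[str],
                         checkpoint: str = "google/vit-base-patch16-224"):
    """Load pretrained ViT weights with a new output layer for the Pokémon labels."""
    # Imported here so the label helper above works without transformers installed.
    from transformers import ViTForImageClassification

    id2label, label2id = build_label_maps(pokemon_names)
    return ViTForImageClassification.from_pretrained(
        checkpoint,
        num_labels=len(id2label),       # resize the classification head
        id2label=id2label,              # stored in the model config
        label2id=label2id,
        ignore_mismatched_sizes=True,   # discard the old 1,000-class head
    )


id2label, label2id = build_label_maps(["pikachu", "bulbasaur", "charmander"])
print(id2label)  # {0: 'bulbasaur', 1: 'charmander', 2: 'pikachu'}
```

Because the mappings live in the model config, they are saved alongside the weights and are available again at inference time.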

Aside from these changes, the code follows the standard PyTorch Lightning template that you can read up on here if you’re interested. You may also refer to my source code shown below.

Once the Pokémon LightningModule is defined, we can begin training our Pokémon classifier with the code shown below.

Results

Validation Accuracy

I was pleasantly surprised by the results of the model. The models trained on the augmented and non-augmented data both did very well on the validation data, with accuracies above 95%. However, I want to stress that the validation data was not intended to be very difficult. It was mainly sourced from Veekun sprites that I did not use during training (i.e., not exact replicas of, but similar to, instances in the training data). I also added a few hand-picked examples to increase the difficulty somewhat, but these were low in quantity. Even with this easy validation dataset in mind, the model seemed to perform well on many of the images I tried using the frontend demo that I built.

Additionally, I was surprised by the lack of differences between the augmented training dataset and the dataset without any augmentation. Some of this could be chalked up to the validation data I used. Increasing the difficulty of this may widen the gap slightly between the two. On the other hand, this could also indicate that ViT requires less data on downstream tasks than I anticipated. Perhaps this is worth revisiting in a subsequent blog post with more test scenarios.

Using the Model

Using the model in practice is dead simple. All that is required is to point to the directory of the saved, fine-tuned model. Once complete, you can feed in your desired Pokémon image to get the output label. The code below shows a simplistic implementation of this.
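A minimal inference sketch along these lines could look like the following. It is an assumption-laden illustration rather than the demo's code: the function names are hypothetical, `AutoImageProcessor` requires a reasonably recent version of `transformers`, and the model directory is assumed to contain both the fine-tuned weights and the preprocessing config.

```python
import math
from typing import Dict, List, Sequence, Tuple


def top_predictions(logits: Sequence[float], id2label: Dict[int, str],
                    k: int = 3) -> List[Tuple[str, float]]:
    """Softmax the raw logits and return the k most likely (label, probability) pairs."""
    highest = max(logits)
    exps = [math.exp(value - highest) for value in logits]  # subtract max for stability
    total = sum(exps)
    probs = [value / total for value in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


def classify_image(image_path: str, model_dir: str, k: int = 3) -> List[Tuple[str, float]]:
    """Load the fine-tuned model from `model_dir` and classify one image."""
    # Heavy dependencies are imported here so the helper above stays usable alone.
    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, ViTForImageClassification

    processor = AutoImageProcessor.from_pretrained(model_dir)
    model = ViTForImageClassification.from_pretrained(model_dir)
    model.eval()

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")  # resize and normalize
    with torch.no_grad():
        logits = model(**inputs).logits[0]
    return top_predictions(logits.tolist(), model.config.id2label, k)


# The ranking helper on its own, with made-up logits for three labels.
labels = {0: "bulbasaur", 1: "charmander", 2: "squirtle"}
print(top_predictions([2.0, 0.5, 1.0], labels, k=2))
```

Returning the top few labels with their probabilities, rather than a single name, makes it easier to spot cases where the model is torn between a Pokémon and one of its evolutions.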

You may also refer to the FastAPI code that I wrote for the demo here if you’d like more depth on how to use it in practice. I’ve also made the fine-tuned model publicly available on Hugging Face. It could potentially be helpful as a starting point if you’re doing image classification on different types of animated images.

Conclusion

That concludes this tutorial! This tutorial covered everything from data collection to fine-tuning a ViT-based model. Feel free to let me know what you think in the comments.
