Video Classification using PyTorch Lightning Flash and the X3D family of models

Dream AI
Aug 23, 2023


Author: Rafay Farhan at DreamAI Software (Pvt) Ltd

A brief introduction to Video Classification via Lightning Flash

Video Classification is the machine learning task of assigning labels to actions identified in a given video. The main objective is to predict the specific class to which the video clips belong.

A simple and highly flexible way to do this is the PyTorch Lightning Flash API. Built on top of PyTorch Lightning, Flash is a collection of tasks for fast prototyping, establishing baselines, and fine-tuning scalable deep learning models. Its primary advantage is flexibility. All data loading in Flash goes through a from_* class method on a DataModule. Lightning DataModules are shareable, reusable objects that encapsulate all data-related code.

Flash enables quick loading of videos and labels from various formats such as config files or folders into DataModules.
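The sketch below illustrates the from_* class-method idea together with the folder format. It is a hypothetical, stdlib-only stand-in: ToyVideoData is not a Flash class. It assumes the layout that from_folders expects, root/<label>/<video file>, where each sub-folder name becomes the label of the videos inside it. The set of video extensions is also an assumption.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import List

VIDEO_SUFFIXES = {".mp4", ".avi", ".mkv"}  # assumed set of video extensions


@dataclass
class ToyVideoData:
    """Hypothetical stand-in for a DataModule; not part of Flash."""

    paths: List[str] = field(default_factory=list)
    targets: List[str] = field(default_factory=list)

    @property
    def labels(self) -> List[str]:
        # Unique class names, sorted so the label order is deterministic.
        return sorted(set(self.targets))

    @classmethod
    def from_folders(cls, root: str) -> "ToyVideoData":
        """Alternative constructor: build the object from root/<label>/<video file>."""
        data = cls()
        for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
            for video in sorted(class_dir.iterdir()):
                if video.suffix.lower() in VIDEO_SUFFIXES:
                    data.paths.append(str(video))
                    data.targets.append(class_dir.name)  # folder name = label
        return data
```

The class-method constructor keeps the "where does the data come from" logic in one named entry point. Further from_* variants, such as from CSV files, could sit alongside it without changing the rest of the class.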

Since the task at hand is video classification, we use the Flash class VideoClassificationData to create the DataModule and the VideoClassifier class to load the model and its pretrained weights. Both classes rely on PyTorchVideo.

The Drawback

The Flash docs website offers a standard, easy-to-follow Video Classification tutorial. Because Flash abstracts most of the code away, running the complete example for training and inference is seamless. However, you may notice that the model used in the example is “x3d_xs”, the smallest model in the X3D family (for more information, refer to the PyTorchVideo Model Zoo).

The X3D section of the Model Zoo lists larger models than “x3d_xs”. More importantly, these larger models are more accurate, with higher top-1 and top-5 accuracy. This isn’t to discredit x3d_xs; it serves its purpose well, especially for experimentation and inference. Nevertheless, for larger datasets and projects, the “x3d_m” model is likely the better choice. Now it’s time to put that into action. The tutorial loads the model with the VideoClassifier class, so changing the string passed to the “backbone” argument to “x3d_m” loads the new model with its corresponding weights. So far, so good.

model = VideoClassifier(
    backbone="x3d_m",  # The name of the model to be used
    labels=datamodule.labels,
    pretrained=True,
)
trainer = flash.Trainer(
    max_epochs=2,
    accelerator="gpu",
    devices=1,
)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

However, when the training code runs, an error is raised before the first epoch starts, and training halts.

RuntimeError: input image (T: 8 H: 8 W: 8) smaller than kernel size (kT: 16 kH: 7 kW: 7)

Although the error message is explicit, resolving it isn’t as straightforward due to the abstraction provided by Flash. Despite scouring the internet, I was unable to find a solution to this problem. Yet, all is not lost. To address this, we need a basic understanding of video models and a deeper exploration of the behind-the-scenes source code of the Lightning Flash library.

A deep dive to find the solution

For this problem, we don’t need the intricate details of how video models, or X3D models in particular, work. If you want to explore further, I highly recommend Christoph Feichtenhofer’s research paper on the X3D model family.

The error says that our input is smaller than the kernel size. A video model processes a sequence of images, i.e., a video, so its inputs have dimensions (T, H, W). H and W are the height and width. T is the temporal dimension: the number of frames, which captures how the video evolves over time.

The error message shows that the H and W dimensions (8) are larger than the kernel’s (7), but the T dimension (8) is smaller than the kernel’s (16), so the check fails. Strictly speaking, the “input image” in the message is the feature map that reaches a pooling layer near the end of the network, not the raw frames. Its T dimension, however, comes directly from the number of frames we feed in.

The catch is that we never set this dimension ourselves, and it isn’t exposed as an argument anywhere. So how and where do we change it so that training can start? This is where we dig deeper.
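The small sketch below reproduces the numbers in the error. It rests on one simplification of ours, not something stated by Flash: X3D keeps the number of frames unchanged and shrinks height and width by a factor of 32. That matches the error, since a 256-pixel crop (the image_size used later in this article) gives 256 / 32 = 8. The kernel (16, 7, 7) is copied from the error message. The helper names are hypothetical.

```python
from math import ceil
from typing import Tuple

Shape3D = Tuple[int, int, int]

# (kT, kH, kW) of the pooling kernel, copied from the x3d_m error message.
X3D_M_POOL_KERNEL: Shape3D = (16, 7, 7)


def pooled_feature_shape(num_frames: int, crop_size: int, spatial_stride: int = 32) -> Shape3D:
    """Estimate the (T, H, W) feature map that reaches the final pooling layer.

    Simplifying assumption: T stays equal to the number of input frames,
    while H and W shrink by `spatial_stride` overall (rounding up).
    """
    side = ceil(crop_size / spatial_stride)
    return (num_frames, side, side)


def fits_kernel(feature: Shape3D, kernel: Shape3D) -> bool:
    """An unpadded pooling window needs every input dimension >= the kernel's."""
    return all(f >= k for f, k in zip(feature, kernel))


for frames in (8, 16):
    shape = pooled_feature_shape(frames, crop_size=256)
    print(frames, shape, fits_kernel(shape, X3D_M_POOL_KERNEL))
# 8 (8, 8, 8) False
# 16 (16, 8, 8) True
```

With 8 frames, the estimate matches the (T: 8 H: 8 W: 8) in the error. With 16 frames, every dimension is at least as large as the kernel.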

To start, clone the Flash repository to your machine using Git, then open the Lightning Flash folder in your preferred integrated development environment (IDE). I personally use Visual Studio Code.

Tip: In VS Code, “Ctrl + click” on a library or import takes you directly to its source code.
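Cloning the repository is one option. Alternatively, Python's built-in inspect module can report where an installed class is defined; this is a swap on our part, not the workflow described here. The helper name locate_source is hypothetical. The snippet runs on a standard-library class so it works without Flash. The commented lines show how it might be pointed at the Flash class, assuming Flash is installed.

```python
import inspect
import json
from typing import Tuple


def locate_source(obj) -> Tuple[str, int]:
    """Return the file that defines `obj` and the line its definition starts on."""
    path = inspect.getsourcefile(obj)
    _, start_line = inspect.getsourcelines(obj)
    return path, start_line


# Runs anywhere: a standard-library class stands in for the Flash one.
print(locate_source(json.JSONDecoder))

# With Flash installed, the same call should point at the transform class:
# from flash.video.classification.input_transform import VideoClassificationInputTransform
# print(locate_source(VideoClassificationInputTransform))
```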

Once the folder is open, find the source code of the VideoClassificationData class. The first line inside the class reads:

input_transform_cls = VideoClassificationInputTransform

Use the tip above to navigate to the VideoClassificationInputTransform class. Within this class, you’ll find the temporal_sub_sample attribute, which has a fixed default of 8. For our scenario, this value needs to be equal to or greater than the T dimension of the kernel, which in our case is 16. The solution, therefore, is to create a custom transform. Although this may seem overwhelming and somewhat scattered at this point, we will now show, step by step, how to use the x3d_m model with just a few adjustments.
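temporal_sub_sample is the number of frames that UniformTemporalSubsample keeps from each clip, and that number becomes T. To our reading, PyTorchVideo computes the indices with torch.linspace from the first to the last frame and truncates them to integers. The stdlib sketch below mirrors that idea with exact integer arithmetic; it is not PyTorchVideo's code. The function name is hypothetical, and the 30-frame clip (a 1-second clip at an assumed 30 fps) is only an example.

```python
from typing import List


def uniform_temporal_indices(num_frames: int, num_samples: int) -> List[int]:
    """Pick `num_samples` evenly spaced frame indices from a clip of `num_frames` frames.

    Spreads the samples from frame 0 to frame num_frames - 1 and truncates to
    integers, mirroring the idea behind UniformTemporalSubsample.
    """
    if num_samples == 1:
        return [0]
    return [(i * (num_frames - 1)) // (num_samples - 1) for i in range(num_samples)]


print(uniform_temporal_indices(30, 8))   # [0, 4, 8, 12, 16, 20, 24, 29] -> T = 8
print(len(uniform_temporal_indices(30, 16)))  # 16 -> T = 16, enough for kT = 16
```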

Step-by-Step Approach

We will use the Video Classification tutorial provided by Flash as a reference, with a few adjustments and customizations.

Prerequisite Imports

from typing import Callable

import kornia.augmentation as K
import torch
from pytorchvideo.transforms import UniformTemporalSubsample
from torch import Tensor
from torchvision.transforms import CenterCrop, Compose, RandomCrop

import flash
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier

Step 1- Create a Custom Transform

This is where the real magic happens. First, we add a helper function, normalize, taken directly from the VideoClassificationInputTransform source code.

def normalize(x: Tensor) -> Tensor:
    return x / 255.0

Next, we write our custom transform class, in which the only change is setting temporal_sub_sample to 16; no other modifications are necessary. To do this, we replicate the approach Flash uses for its own transform class: our class also inherits from InputTransform, but we name it “TransformDataModule”.

class TransformDataModule(InputTransform):
    image_size: int = 256
    temporal_sub_sample: int = 16  # This is the only change in our custom transform
    mean: Tensor = torch.tensor([0.45, 0.45, 0.45])
    std: Tensor = torch.tensor([0.225, 0.225, 0.225])
    data_format: str = "BCTHW"
    same_on_frame: bool = False

    # Default per-sample pipeline, used by stages without their own override
    # (e.g. validation and prediction): subsample frames, scale to [0, 1], center-crop.
    def per_sample_transform(self) -> Callable:
        per_sample_transform = [CenterCrop(self.image_size)]

        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    # Training pipeline: same as above, but with a random crop (padded if needed).
    def train_per_sample_transform(self) -> Callable:
        per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)]

        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    # Applied to whole batches on the device: standardize with the mean/std above.
    def per_batch_transform_on_device(self) -> Callable:
        return ApplyToKeys(
            DataKeys.INPUT,
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )
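The transform normalizes in two stages. First, the per-sample normalize helper divides 8-bit pixel values by 255. Then per_batch_transform_on_device applies K.Normalize with mean 0.45 and std 0.225, that is, (x - mean) / std per channel. The scalar sketch below traces single pixel values through both stages. It is plain Python for illustration, not the tensor code the transform actually runs, and the function names are hypothetical.

```python
MEAN = 0.45  # per-channel mean in TransformDataModule (all three channels share it)
STD = 0.225  # per-channel std in TransformDataModule


def rescale(pixel: float) -> float:
    """Per-sample step, like the `normalize` helper: 8-bit value -> [0, 1]."""
    return pixel / 255.0


def standardize(value: float, mean: float = MEAN, std: float = STD) -> float:
    """Per-batch step, the formula K.Normalize applies channel-wise: (x - mean) / std."""
    return (value - mean) / std


def two_stage_normalize(pixel: float) -> float:
    """Both stages in order: rescale first, then standardize."""
    return standardize(rescale(pixel))


for pixel in (0, 128, 255):
    print(pixel, round(two_stage_normalize(pixel), 3))
```

A black pixel maps to roughly -2.0 and a white pixel to roughly 2.44, so the values the model sees are centered near zero.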

Step 2- Download the Kinetics data

The Flash tutorial page provides a very efficient way to download a subset of the Kinetics data, organized in folders. This step is optional; skip it if you are using your own data.

download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data")
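Flash's download_data fetches the archive and extracts it, which is why the ./data/kinetics folders used in the next step exist afterwards. If you prefer not to rely on that helper, a rough standard-library equivalent might look like the sketch below. download_and_extract is a hypothetical name, and the sketch skips the progress bar and any error handling.

```python
import os
import urllib.request
import zipfile


def download_and_extract(url: str, dest_dir: str) -> str:
    """Fetch a .zip archive into dest_dir and unpack it there; returns the archive path."""
    os.makedirs(dest_dir, exist_ok=True)
    archive_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive_path):  # avoid re-downloading on reruns
        urllib.request.urlretrieve(url, archive_path)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(dest_dir)
    return archive_path


# Usage with the tutorial's archive (needs network access):
# download_and_extract("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data")
```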

Step 3- Create the DataModule by adding our custom TransformDataModule()

We follow the same procedure as in the tutorial, with a single change: this time, we pass our custom transform to the transform argument.

datamodule = VideoClassificationData.from_folders(
    train_folder="./data/kinetics/train",
    val_folder="./data/kinetics/val",
    clip_sampler="uniform",
    clip_duration=1,
    decode_audio=False,
    transform=TransformDataModule(),  # The custom transform is given to the datamodule's transform argument
    batch_size=1,
)

Step 4- Load the x3d_m model

Just change the string passed to the backbone argument to “x3d_m”.

model = VideoClassifier(
    backbone="x3d_m",  # The backbone is changed to the name of the desired model
    labels=datamodule.labels,
    pretrained=True,
)

Step 5- Create the trainer and start the training process

trainer = flash.Trainer(
    max_epochs=2,
    accelerator="gpu",
    devices=1,
)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

Inference

That’s it: your new x3d_m model has been trained. To run inference, again pass your custom transform to the prediction DataModule.

datamodule_p = VideoClassificationData.from_folders(
    predict_folder="data/kinetics/predict",
    transform=TransformDataModule(),  # The custom transform is given to the datamodule's transform argument
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=datamodule_p, output="labels")
print(predictions)
# Output:
# [['archery'], ['high_jump'], ['flying_kite'], ['marching'], ['high_jump']]
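The printed result is a list of lists. With batch_size=1, each inner list appears to hold the predicted label for one video. The small stdlib sketch below flattens and tallies the output shown above. flatten_predictions is our own hypothetical helper, not part of Flash.

```python
from collections import Counter
from typing import List


def flatten_predictions(batches: List[List[str]]) -> List[str]:
    """Join per-batch label lists into one flat list of labels."""
    return [label for batch in batches for label in batch]


# The printed output from above, copied verbatim.
predictions = [["archery"], ["high_jump"], ["flying_kite"], ["marching"], ["high_jump"]]
flat = flatten_predictions(predictions)
print(flat)  # ['archery', 'high_jump', 'flying_kite', 'marching', 'high_jump']
print(Counter(flat).most_common(1))  # [('high_jump', 2)]
```

The same helper works for larger batch sizes, where each inner list would hold several labels.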
