Feature Extraction Using CNNs via PyTorch

In this article, we will explore CNN feature extraction using PyTorch, a popular deep learning library. We will go over what feature extraction is, why it is useful, and a code implementation. Although there are a lot of great in-depth articles on feature extraction, in my experience they are often very lengthy, so my hope with this article is to provide a short yet helpful method for feature extraction using PyTorch.

Rabia Gondur
6 min read · Dec 6, 2023
Photo by Clint Adair on Unsplash

Feature Extraction

Feature extraction is an important method in machine learning and computer vision, where it is applied to data such as images to pull out the salient features. You have probably seen many articles use pre-trained large image models such as AlexNet or ResNet to extract useful features from images. But why do we think these models have good representations? They are trained on large amounts of data to classify a given image into a specific class, and to do that well, the network has to learn the crucial features in the data that drive predictive performance. This is why articles on feature extraction typically use everything up to the fully connected layers at the end: those final layers are usually task-specific, so they are not very versatile. But where did the idea that convolutional neural networks capture features come from? Convolutional Neural Networks (CNNs) are known for processing data hierarchically, where the earlier layers describe very basic, low-level features and the complexity increases as you move down the layers.

This characteristic of CNNs somewhat mimics the hierarchy of human visual processing, so they are also great in silico tools for studying some of the visual processing patterns in humans. I will later have a separate article that goes more in depth on CNNs, but for the purposes of this article, here are some key components of CNNs (a short code sketch follows the list):

  1. Convolutional Layers: These capture local patterns in the input data by scanning it with filters. This enables the network to identify features like edges and textures, allowing the model to focus on important local information.
  2. Pooling: Pooling downsamples the spatial dimensions of the feature maps (the outputs of convolutional layers). It plays a big role in the translation invariance of CNNs and reduces the computational load.
  3. Activation functions: These are the same activation functions we see in any neural network. They introduce non-linearities into the model, enabling it to learn complex relationships in the data.
  4. Fully Connected Layers: Fully connected layers enable the network to consolidate learned features and make high-level abstractions. They help the network understand global patterns and relationships in the data, and they are usually the most task-specific part of the model, since they sit at the very end, right before classification.
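
To make these components concrete, here is a minimal sketch of a toy CNN in PyTorch; the architecture and layer sizes are made up purely for illustration:

import torch.nn as nn

class TinyCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)  # 1. convolutional layer
        self.act = nn.ReLU()                                    # 3. activation function
        self.pool = nn.MaxPool2d(2)                             # 2. pooling
        self.fc = nn.Linear(16 * 16 * 16, num_classes)          # 4. fully connected layer

    def forward(self, x):  # x: (batch, 3, 32, 32), e.g. a CIFAR10-sized image
        x = self.pool(self.act(self.conv(x)))  # conv -> activation -> pool
        return self.fc(x.flatten(1))           # flatten, then classify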

Now that we have a general idea of CNNs, let’s look at a specific CNN architecture: ResNet50.

Hands-on Example: Feature Extraction using pre-trained ResNet50

I want to preface this by saying that there isn’t only one way to extract features in neural networks; you can use various methods depending on how much you want to modify the original network or the tools provided by the deep learning library you use. Thus, it is up to you to choose the most convenient and practical method for your use case. However, this method worked for me, so I hope it can be helpful for you!

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

First, we import the necessary libraries. I have included some extra libraries that I used in my full code, so you can trim this to your liking or add others for further analysis. In this example I also omit data handling, as it is specific to your task. However, if you would like to try this end to end, you can use one of the datasets already available in PyTorch, such as CIFAR10.
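
For instance, a minimal sketch of loading CIFAR10 through torchvision might look like this (resizing to 224×224 matches the input size ResNet50 expects, and the normalization values are the standard ImageNet statistics):

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=False)

Now let’s get to loading the model.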

resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)  # pretrained=True in older torchvision versions
resnet50.eval()  # evaluation mode, so layers like batch norm behave deterministically

This process is pretty straightforward, as we just use models from torchvision. To get a sense of what the model looks like, we can loop over the named children of the resnet50 module.

for i, x in resnet50.named_children():
    print(f'NAME: \n {i}')
    print(f'CONTENT: \n {x}')

This will give us something like this image:

Image 1: A snippet of ResNet50 architecture
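
If you run the loop yourself, the first few entries of the printout look roughly like this (the exact arguments may vary slightly across torchvision versions):

NAME:
 conv1
CONTENT:
 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
NAME:
 bn1
CONTENT:
 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
NAME:
 relu
CONTENT:
 ReLU(inplace=True)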

Here we see the name of each layer and the actual operation that goes into it. For example, conv1 applies a Conv2d to our data, bn1 then does batch normalization, and so on. This looks pretty simple, as we could just pick a specific layer by name and use it as our feature extractor. However, these are only the first few layers of ResNet50; as we go deeper, we see something more like this:

Image 2: A snippet of ResNet50 architecture (later layers)

As you can imagine, we can’t just use the name layer1 if we want access to a specific part of layer1, because its operations are wrapped inside Sequential and Bottleneck blocks. So if we want to access the convolution highlighted by the red rectangle, the name layer1 alone won’t get us there, since there are also other convolutions within layer1. To overcome this challenge we can use something called hooks.

Image 3: Image 2 with a specific layer highlighted
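
As a side note, the nested module itself is still reachable through ordinary attribute and index access; what hooks buy us is a way to capture that module’s activations during a forward pass. A quick sketch:

# reach into the first Bottleneck block of layer1 and grab its first convolution
inner_conv = resnet50.layer1[0].conv1
print(inner_conv)  # Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)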

Below is an example, adapted from a pattern commonly shared in the PyTorch documentation and forums, where we define a function that extracts the activations from a specific layer via a forward hook.

embeds = {}

def get_activation(name):
    # returns a hook that stores the layer's output under the given name
    def hook(model, input, output):
        embeds[name] = output.detach()  # detach so we don't keep the autograd graph around
    return hook

After defining a function to extract the activation of a given layer, let’s assume we are interested in the conv1, conv2, and conv3 layers inside those Sequential and Bottleneck blocks. When you check the full summary of ResNet50, you will see that the Sequential blocks are named ‘layer1’ through ‘layer4’, so in the code below I loop over exactly those children, and then loop again over the Bottleneck blocks inside each one. For every Bottleneck, I register PyTorch’s forward hook on a given layer, such as conv1, using the function above to capture its activation. Since embeds is a dictionary, each hook stores its activations under a key like conv1_layer1_0, built from the convolution’s name, the parent layer’s name, and the block index, so we can later use the key to retrieve specific activations. Once we are done extracting features, we remove the hooks (after the forward pass) so that they don’t keep firing and overwriting the stored activations on later calls to the model.

hooks = []
for name, layer in resnet50.named_children():
    if name in ['layer1', 'layer2', 'layer3', 'layer4']:
        for i, bottleneck in enumerate(layer.children()):
            # keep every handle so we can remove all hooks later, not just the last three
            hooks.append(bottleneck.conv1.register_forward_hook(get_activation(f'conv1_{name}_{i}')))
            hooks.append(bottleneck.conv2.register_forward_hook(get_activation(f'conv2_{name}_{i}')))
            hooks.append(bottleneck.conv3.register_forward_hook(get_activation(f'conv3_{name}_{i}')))

After this, you can pass your input through the model as you normally would, read the activations out of embeds, and finally remove the hooks. That’s it! You can now get the intermediate activations and extract features.
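
As a rough end-to-end sketch (the image path is hypothetical, and the preprocessing uses the standard ImageNet statistics):

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
img = Image.open('example.jpg').convert('RGB')  # hypothetical image file
x = preprocess(img).unsqueeze(0)                # add a batch dimension
with torch.no_grad():
    _ = resnet50(x)  # the forward pass triggers every registered hook

print(embeds['conv1_layer1_0'].shape)  # torch.Size([1, 64, 56, 56])

# remove the hooks now that we are done collecting activations
for hook in hooks:
    hook.remove()

As a final note, newer versions of torchvision also ship a built-in utility, torchvision.models.feature_extraction.create_feature_extractor, which can pull intermediate activations without manual hooks; the hook-based approach above is still handy when you want finer control over what gets captured.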

To see the full example with a toy dataset (CIFAR10), see this GitHub repo!


Rabia Gondur

ML researcher | M.S. in Data Science | B.S. in Integrative Neuroscience