Class Activation Maps

Anthony Martinez
6 min read · Jan 18, 2019


I will briefly go over how to modify a pre-trained VGG19 network to output class activation maps (CAM). If you are not familiar with CAM, you can read the paper here. The complete code can be seen in this Jupyter notebook here.

To run the sample code you will need:

  • Current version of Pytorch: here
  • Bees and Ants dataset: here
  • ImageNet weights for VGG19: here

To state our use case: we want to see what the network is looking at when it makes its classification decisions, which means first tackling a classification task. For most 2D classification problems, a great place to start is a network pre-trained on ImageNet; this is called transfer learning. In this case, we will use the hymenoptera dataset (bees and ants) to train a binary image classifier on top of a pre-trained VGG19 network. But, as the title states, our true goal is a classifier good enough to extract class activation maps from.
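The data-loading code is not shown in this post (it lives in the notebook), but a typical setup for the hymenoptera dataset might look like the sketch below. The directory layout, batch sizes, and transforms are my assumptions; the name valid_loader matches the loader used in the inference code later on.

import torch
from torchvision import datasets, transforms

#standard ImageNet preprocessing for a pre-trained VGG19
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

#the dataset ships as hymenoptera_data/{train,val}/{ants,bees},
#so ImageFolder assigns ants=0 and bees=1 alphabetically
train_set = datasets.ImageFolder('hymenoptera_data/train', preprocess)
valid_set = datasets.ImageFolder('hymenoptera_data/val', preprocess)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=1, shuffle=True)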

To jump ahead just a bit, let’s look at what we are after. Below are three images: an input image, its class activation map, and a resized activation map over the original image. From this example, we can see where the network is looking when it is making its classification decision, and that is what we are after. This is why CAM can be useful.

Input image of two ants.
Class activation map of the two ants.
Resized CAM over the original image.

Let’s Begin

Let’s list what we need to accomplish to get CAM working.

  • Use a pre-trained network and freeze most of its weights
  • Modify the network to enable CAM output
  • Train the classifier
  • Use the last convolution to create CAM output
  • Display CAM

VGG19 Modifications

Pytorch’s implementation of VGG is composed of two sequential sections named features and classifier. The features section contains all the convolution layers, while the classifier contains the fully connected layers. I’ve made a slightly modified copy of it for convenience here. Below is the structure of VGG19.

VGG(
(features): Sequential(
(0):Conv2d(3,64,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(1):ReLU(inplace)
(2):Conv2d(64,64,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(3):ReLU(inplace)
(4):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1,
ceil_mode=False)
(5):Conv2d(64,128,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(6):ReLU(inplace)
(7):Conv2d(128,128,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(8):ReLU(inplace)
(9):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1,
ceil_mode=False)
(10):Conv2d(128,256,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(11):ReLU(inplace)
(12):Conv2d(256,256,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(13):ReLU(inplace)
(14):Conv2d(256,256,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(15):ReLU(inplace)
(16):Conv2d(256,256,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(17):ReLU(inplace)
(18):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1,
ceil_mode=False)
(19):Conv2d(256,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(20):ReLU(inplace)
(21):Conv2d(512,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(22):ReLU(inplace)
(23):Conv2d(512,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(24):ReLU(inplace)
(25):Conv2d(512,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(26):ReLU(inplace)
(27):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1,
ceil_mode=False)
(28):Conv2d(512,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(29):ReLU(inplace)
(30):Conv2d(512,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(31):ReLU(inplace)
(32):Conv2d(512,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(33):ReLU(inplace)
(34):Conv2d(512,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(35):ReLU(inplace)
(36):MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1,
ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace)
(2): Dropout(p=0.5)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace)
(5): Dropout(p=0.5)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)

First, let’s load our pre-trained VGG19 network.

#vgg19 with ImageNet weights (using the modified vgg19 described below)
model = vgg19()
#if you were using the original torchvision vgg19,
#you would load the weights like this instead:
#model = vgg19(pretrained=True)
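Depending on how the modified vgg19 is set up, you may also need to load the downloaded ImageNet weights yourself. A minimal sketch of that step (the filename is just a placeholder for wherever you saved the weights):

#the modified class keeps the same layer layout as torchvision's vgg19,
#so the stock ImageNet state dict should load directly,
#as long as you do this before swapping any layers
state_dict = torch.load('vgg19-imagenet.pth')
model.load_state_dict(state_dict)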

Next, let’s freeze our layers.

#freeze layers
for param in model.parameters():
    param.requires_grad = False

Since we are going to output two classes, bees and ants, we will need to modify the last convolution from outputting 512 channels to just 2 channels. These two channels will show the activations that are triggered during inference.

#I got better results when I changed the last two convolutions
#which is why you see features[-5]
model.features[-5] = nn.Conv2d(512,512,3, padding=1)
model.features[-3] = nn.Conv2d(512,2,3, padding=1)
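As a quick sanity check (not from the original post), you can confirm that only the two new convolutions will be trained, since freshly constructed modules default to requires_grad=True while everything else was frozen above:

#list the parameters that will actually be updated during training
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)
#expected: the weight and bias of features[32] and features[34]
#(i.e. features[-5] and features[-3])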

According to the CAM paper, fully connected layers destroy valuable localization information. Therefore, let’s discard that section in favor of global average pooling with LogSoftmax. Since I am using Pytorch, I will be using adaptive average pooling.

#remove the fully connected layers and replace them with AdaptiveAvgPooling
model.classifier = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    Flatten(),
    nn.LogSoftmax(dim=1)
)
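Note that Flatten here is not a built-in module (nn.Flatten did not exist in this version of Pytorch); the notebook defines a small helper for it. A minimal version would be:

import torch.nn as nn

class Flatten(nn.Module):
    #collapse the (N, C, 1, 1) output of the adaptive pooling into (N, C)
    def forward(self, x):
        return x.view(x.size(0), -1)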

The bottom of the network should now look like this:

VGG(
...top layers are the same...

(32): Conv2d(512,512,kernel_size=(3,3), stride=(1,1),padding=(1,1))
(33): ReLU(inplace)
(34): Conv2d(512,2,kernel_size=(3,3),stride=(1,1),padding=(1,1))
(35): ReLU(inplace)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): AdaptiveAvgPool2d(output_size=1)
(1): Flatten()
(2): LogSoftmax()
)
)

Now the network is almost ready. Next, we must modify the forward function in the VGG class [code].

class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            #some more code...see [here] for the full code
        )

    def forward(self, x):
        x = self.features(x)
        #don't flatten here
        #x = x.view(x.size(0), -1)
        #now when you pass the convolutions to the classifier,
        #the data will stay intact and you will be able to use
        #global pooling (AdaptiveAvgPool2d)
        x = self.classifier(x)
        return x

That’s it! Next, train the network.

Training

For training I used the following settings:

  • Loss: Cross Entropy Loss
  • Learning rate: 0.0001
  • Optimizer: Adam
  • Epochs: 200
  • Result: best validation accuracy of 95%

To see the training loop, take a look at the notebook [here].
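The exact loop lives in the notebook, but a minimal sketch that matches the settings above (and assumes the train_loader/valid_loader names from the data-loading sketch earlier) could look like this:

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

#the classifier already ends in LogSoftmax, so NLLLoss on its output
#is the same computation as cross entropy loss
criterion = nn.NLLLoss()
#only the replaced layers have requires_grad=True, so only they are optimized
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=0.0001)

for epoch in range(200):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    #simple validation accuracy
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print('epoch {}: val acc {:.3f}'.format(epoch, correct / total))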

CAM

Now that we have a trained network, let’s run inference on an image. But before we run the image through the model, we first need a small helper to save the feature map from the last convolution. Thanks to the nice people at Fastai, we have a class that does just that:

class SaveFeatures():
    features = None
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.features = output
    def remove(self):
        self.hook.remove()

To use this class, simply pass in the last convolution layer before running inference.

#model.features[-3] corresponds to the last convolution
sf = SaveFeatures(model.features[-3])

With SaveFeatures in place, let’s get a sample and run it through our model.

#get image and label
im, lab = next(iter(valid_loader))
im = Variable(im.cuda())
lab = Variable(lab.cuda())
outputs = model(im)
#classes: [ant, bee]
res = torch.argmax(outputs.data).cpu().detach().numpy()
if res == 1:
    print('result: bee', res)
else:
    print('result: ant', res)

Let’s take a look at the current sample image.

Current input.

Now, let’s grab the saved feature map:

sf.remove()
arr = sf.features.cpu().detach().numpy()
features_data = arr[0]

Since our two classes are [ants, bees], the corresponding one-hot labels are [1,0] for ants and [0,1] for bees. Therefore, the last step to get the heat map is to take the dot product of the feature map and the one-hot label.

ans_ant = np.dot(np.rollaxis(features_data,0,3), [1,0])
ans_bee = np.dot(np.rollaxis(features_data,0,3), [0,1])

The resulting arrays show the activations for the ant and bee classes. The first map, which corresponds to the incorrect class, shows very weak activations.

Ant activation.

The second map, for the predicted bee class, shows much stronger activations.

Bee activation.

Finally, you can resize the heat map to the size of the original image and display them together.

Final result.
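The post does not include the plotting code, but one way to produce the overlay (a sketch assuming the normalization constants from the data-loading code above, with skimage used for the resize and an arbitrary alpha) is:

import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize

#undo the ImageNet normalization so the input displays correctly
img = im[0].cpu().numpy().transpose(1, 2, 0)
img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
img = np.clip(img, 0, 1)

#upscale the coarse heat map (14x14 for a 224x224 input) to the image size
heatmap = resize(ans_bee, (img.shape[0], img.shape[1]))

plt.imshow(img)
plt.imshow(heatmap, alpha=0.5, cmap='jet')
plt.axis('off')
plt.show()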

Conclusion

That’s it. We have modified a pre-trained VGG19 network to solve a binary classification problem and extracted class activation maps. In doing so, we are now privy to what the network is looking at.
