Squeeze and Excitation Implementation in TensorFlow and PyTorch

Nikhil Tomar · Published in Analytics Vidhya · Dec 1, 2021

The Squeeze and Excitation network is a channel-wise attention mechanism used to improve the overall performance of a convolutional neural network. In today’s article, we are going to implement the Squeeze and Excitation module in both TensorFlow and PyTorch.

What is Squeeze and Excitation Network?

The squeeze and excitation attention mechanism was introduced by Hu et al. in their paper “Squeeze-and-Excitation Networks” at CVPR 2018, with a journal version in TPAMI. It is one of the most influential papers in the field of attention mechanisms and has been cited more than 8,000 times.

The Squeeze and Excitation Network introduces a novel channel-wise attention mechanism for CNNs (Convolutional Neural Networks) that improves the modelling of channel interdependencies. The block adds a small number of parameters that re-weight each channel, so the network becomes more sensitive to significant features while suppressing irrelevant ones.

In short, the Squeeze and Excitation Network is a channel-wise attention mechanism that recalibrates each channel to create a more robust representation by enhancing the important features.

READ MORE: Squeeze and Excitation Networks

The block diagram of the Squeeze and Excitation Network.
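Before the framework code, it may help to see the three steps in isolation. Below is a minimal NumPy sketch (my own illustration, not from the original post) of squeeze, excitation, and scale, where the random matrices w1 and w2 stand in for the two learned fully connected layers:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

b, h, w, c, ratio = 2, 8, 8, 32, 8
feature_map = np.random.randn(b, h, w, c)
w1 = np.random.randn(c, c // ratio)  # stands in for the bottleneck FC layer
w2 = np.random.randn(c // ratio, c)  # stands in for the expansion FC layer

s = feature_map.mean(axis=(1, 2))        # squeeze: global average pool -> (b, c)
e = sigmoid(np.maximum(s @ w1, 0) @ w2)  # excite: FC -> ReLU -> FC -> sigmoid
out = feature_map * e[:, None, None, :]  # scale: re-weight each channel
print(out.shape)                         # (2, 8, 8, 32)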

Squeeze and Excitation Implementation in TensorFlow

from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Input

def SqueezeAndExcitation(inputs, ratio=8):
    b, _, _, c = inputs.shape
    # Squeeze: global average pooling reduces each channel to a single value.
    x = GlobalAveragePooling2D()(inputs)
    # Excitation: a two-layer bottleneck learns channel weights in (0, 1).
    x = Dense(c//ratio, activation="relu", use_bias=False)(x)
    x = Dense(c, activation="sigmoid", use_bias=False)(x)
    # Reshape to (1, 1, c) so the weights broadcast over the spatial dimensions.
    x = Reshape((1, 1, c))(x)
    # Scale: re-weight the input feature map channel by channel.
    x = inputs * x
    return x

if __name__ == "__main__":
    inputs = Input(shape=(128, 128, 32))
    y = SqueezeAndExcitation(inputs)
    print(y.shape)
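Since the block returns a tensor with the same shape as its input, it can be dropped in anywhere in a Keras model. As a hypothetical usage example (assuming the SqueezeAndExcitation function above is in scope), it is typically attached right after a convolutional layer:

from tensorflow.keras.layers import Conv2D, Input
from tensorflow.keras.models import Model

# Hypothetical usage: insert the SE block after a convolution.
inputs = Input(shape=(128, 128, 3))
x = Conv2D(32, 3, padding="same", activation="relu")(inputs)
x = SqueezeAndExcitation(x)  # output shape is unchanged: (None, 128, 128, 32)
model = Model(inputs, x)
model.summary()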

Squeeze and Excitation Implementation in PyTorch

import torch
import torch.nn as nn

class SqueezeAndExcitation(nn.Module):
    def __init__(self, channel, ratio=8):
        super().__init__()
        # Squeeze: global average pooling reduces each channel to a single value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: a two-layer bottleneck learns channel weights in (0, 1).
        self.network = nn.Sequential(
            nn.Linear(channel, channel//ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel//ratio, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        b, c, _, _ = inputs.shape
        x = self.avg_pool(inputs)
        x = x.view(b, c)
        x = self.network(x)
        # Reshape to (b, c, 1, 1) so the weights broadcast over the spatial dimensions.
        x = x.view(b, c, 1, 1)
        # Scale: re-weight the input feature map channel by channel.
        x = inputs * x
        return x

if __name__ == "__main__":
    inputs = torch.randn((8, 32, 128, 128))
    se = SqueezeAndExcitation(32, ratio=8)
    y = se(inputs)
    print(y.shape)
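In practice, the module is usually wired into an existing backbone. Here is a minimal sketch (my own illustration, not from the original post) of placing SE on the residual branch before the skip connection, as done in SE-ResNet:

class SEResidualBlock(nn.Module):
    def __init__(self, channels, ratio=8):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SqueezeAndExcitation(channels, ratio=ratio)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)  # recalibrate channels before the skip connection
        return self.relu(out + x)

block = SEResidualBlock(32)
print(block(torch.randn(4, 32, 64, 64)).shape)  # torch.Size([4, 32, 64, 64])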

Conclusion

In this coding tutorial, you have learned about one of the most widely used channel-wise attention mechanisms, the Squeeze and Excitation Network, and how to implement it in both TensorFlow and PyTorch.

Still have some questions or queries? Just comment below. For more updates, follow me.

Originally published at https://idiotdeveloper.com on December 1, 2021.
