Vision Transformer Embrace Convolutional Neural Networks | TEC-Net
A Hybrid CNN-Transformer Architecture for Medical Image Segmentation with DDConv and SW-ACAM, with Quantitative and Qualitative Analysis
Convolutional neural networks (CNNs) are limited in their ability to capture global features, whereas transformers are limited in their ability to extract local features. Hybrid architectures that combine transformers and CNNs have therefore recently been used to overcome the shortcomings of each.
In TEC-Net, Dynamic Deformable Convolution (DDConv) is employed in the CNN branch to adapt to spatial variations in the input image by learning a set of offset maps. Meanwhile, a (Shifted)-Window Adaptive Complementary Attention Module ((S)W-ACAM) and a compact convolutional projection are applied in the Transformer branch to learn cross-dimensional long-range dependencies in images.
Contents
- Overview
- TEC-Net Architecture
- Experimental Results
Overview
A hybrid CNN-Transformer architecture captures both local and global image features by combining the strengths of CNNs and Transformers. However, hybrid architectures still have two limitations.
- Local feature modeling: Vanilla convolution is good at modeling local features, but it may fail to capture the deformations and irregular shapes of organs and lesions. To address this, Dynamic Deformable Convolution (DDConv) is proposed to adapt to spatial variations in the input image, which helps improve the accuracy of local feature extraction.
- Global feature modeling: The Transformer is good at modeling global features, but it may fail to capture the correlation between the spatial and channel dimensions. To address this, a Shifted-Window Adaptive Complementary Attention Module (SW-ACAM) is proposed. It directs the Transformer's attention to different parts of the image and also captures the correlation between the spatial and channel dimensions.
TEC-Net Architecture
To make full use of both the local and global features of medical images, a parallel interactive network architecture called TEC-Net was developed. The overall architecture of TEC-Net is shown in Fig. 1(a).
TEC-Net continuously feeds the local details extracted by the CNN branch to the decoder of the Transformer branch. Similarly, TEC-Net also feeds the global long-range features captured by the Transformer branch to the decoder of the CNN branch.
TEC-Net consists of four components: a patch embedding model, a dynamically adaptive CNN branch, a cross-dimensional fusion Transformer branch, and a feature fusion module.
The dynamically adaptive CNN branch and the cross-dimensional fusion Transformer branch follow the architectures of U-Net and Swin-Unet, respectively.
The dynamically adaptive CNN branch consists of seven key stages and uses DDConv, which adaptively adjusts the weight coefficients and deformation offsets of the convolution at each stage. This improves the segmentation accuracy of irregularly shaped objects in medical images.
Likewise, the cross-dimensional fusion Transformer branch consists of seven stages, each of which uses the (S)W-ACAM attention module, as depicted in Figure 1(b). This helps the segmentation network capture the global dependencies in medical images, including positional relationships between organs. It also improves the separability of segmented objects from the background.
By improving both local and global feature representations, TEC-Net outperforms pure CNN or pure Transformer networks and shows promising potential for medical image segmentation.
To reduce the number of parameters and the computational cost of the Transformer branch, a lightweight perceptron module (LPM), inspired by GhostNet, is used in place of the standard MLP.
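The source does not describe the internal structure of the LPM. The sketch below therefore only illustrates the GhostNet idea it is inspired by. A Ghost module produces n output channels by first computing n/s "intrinsic" channels with an ordinary convolution, then generating the remaining (s - 1)·n/s channels with cheap d×d depthwise operations. The function names and the example sizes (dim = 96, hidden width 4·dim, s = 2, d = 3) are hypothetical choices for illustration, not values reported for TEC-Net.

```python
def linear_params(c_in: int, c_out: int) -> int:
    """Weights of a standard 1x1 convolution / linear layer (bias ignored)."""
    return c_in * c_out


def ghost_params(c_in: int, c_out: int, s: int = 2, d: int = 3, k: int = 1) -> int:
    """Weights of a GhostNet-style module (bias ignored).

    A k x k primary convolution produces c_out // s intrinsic channels; cheap
    d x d depthwise operations generate the remaining (s - 1) * (c_out // s).
    """
    intrinsic = c_out // s
    primary = c_in * intrinsic * k * k
    cheap = (s - 1) * intrinsic * d * d
    return primary + cheap


def mlp_params(dim: int, ratio: int = 4, ghost: bool = False) -> int:
    """Two-layer Transformer MLP dim -> ratio*dim -> dim.

    With ghost=True both layers are swapped for Ghost-style layers, a
    hypothetical stand-in for the LPM, whose exact design is not given.
    """
    layer = ghost_params if ghost else linear_params
    hidden = ratio * dim
    return layer(dim, hidden) + layer(hidden, dim)


print(mlp_params(96), mlp_params(96, ghost=True))  # 73728 39024
```

Under these assumptions the Ghost-style MLP needs roughly half the weights. The saving comes almost entirely from computing only half of each layer's outputs with a full convolution.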
Dynamic Deformable Convolution (DDConv)
Vanilla convolutions are inherently limited in their ability to adapt and generalize to the complex feature representations of medical images. To address this issue, Dynamic Deformable Convolution (DDConv) is employed.
As depicted in Figure 2, DDConv adaptively learns kernel deformation offsets and weight coefficients from the data distribution, so both the shapes and the values of the convolution kernels can change with the input. This yields improved feature representation and generalization performance.
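The offset part of this idea can be illustrated with a minimal, framework-free sketch of a deformable 3×3 convolution evaluated at a single output pixel. The sketch makes several assumptions:

- The offsets are supplied directly. In DDConv they are learned from the data, typically by an extra convolution layer, which is not modeled here.
- The input is read at fractional positions with bilinear interpolation, as in standard deformable convolution.
- `bilinear` and `deformable_conv_at` are hypothetical helper names.

```python
import math


def bilinear(img, y, x):
    """Sample a 2D list `img` at fractional (y, x); out-of-bounds pixels count as 0."""
    h, w = len(img), len(img[0])
    y0, x0 = math.floor(y), math.floor(x)
    val = 0.0
    for yy in (y0, y0 + 1):
        for xx in (x0, x0 + 1):
            if 0 <= yy < h and 0 <= xx < w:
                # weight falls off linearly with distance to the integer pixel
                val += (1 - abs(y - yy)) * (1 - abs(x - xx)) * img[yy][xx]
    return val


def deformable_conv_at(img, kernel, offsets, py, px):
    """3x3 deformable convolution response at output pixel (py, px).

    offsets[k] = (dy, dx) shifts the k-th kernel tap (row-major order) away
    from its regular grid position; all-zero offsets give a plain convolution.
    """
    out = 0.0
    k = 0
    for i in (-1, 0, 1):
        for j in (-1, 0, 1):
            dy, dx = offsets[k]
            out += kernel[i + 1][j + 1] * bilinear(img, py + i + dy, px + j + dx)
            k += 1
    return out


img = [[float(r * 5 + c) for c in range(5)] for r in range(5)]
box = [[1 / 9] * 3 for _ in range(3)]
print(deformable_conv_at(img, box, [(0.0, 0.0)] * 9, 2, 2))  # regular 3x3 average
print(deformable_conv_at(img, box, [(0.0, 0.5)] * 9, 2, 2))  # taps shifted half a pixel right
```

Because each tap has its own (dy, dx), the effective sampling grid can bend to follow an irregular boundary rather than staying a fixed square.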
The output feature map of a vanilla convolution is

$$y = \sigma(W \ast x)$$

where σ is the activation function, W is the convolutional kernel weight matrix, x is the input feature map, ∗ denotes convolution, and y is the output feature map.
The output feature map of DDConv is

$$\hat{y} = \sigma\big((\alpha_1 W_1 + \alpha_2 W_2 + \cdots + \alpha_n W_n) \ast x\big)$$

where n is the number of weight coefficients, α_i (i = 1, …, n) are weight coefficients with learnable parameters, and ŷ is the output feature map generated by DDConv.
In other words, DDConv dynamically adjusts the convolution kernel by combining several weight matrices, each scaled by its corresponding weight coefficient, before applying the convolution operation.
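The weight-mixing step ŷ = σ((α_1 W_1 + … + α_n W_n) ∗ x) can be sketched in plain Python. The sketch makes these assumptions:

- The coefficients α are a softmax over logits that are passed in directly. Dynamic-convolution methods usually compute such logits from the pooled input with a small routing network, which is not modeled here.
- σ is ReLU.
- The input has a single channel and the convolution is "valid" with no padding.

All function names are hypothetical.

```python
import math


def softmax(z):
    """Numerically stable softmax; turns routing logits into coefficients alpha."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def mix_kernels(kernels, alphas):
    """Return sum_i alphas[i] * kernels[i] for equally sized 2D kernels."""
    kh, kw = len(kernels[0]), len(kernels[0][0])
    return [[sum(a * k[r][c] for a, k in zip(alphas, kernels)) for c in range(kw)]
            for r in range(kh)]


def conv2d_valid(x, w):
    """'Valid' 2D cross-correlation, i.e. what DL frameworks call convolution."""
    h, wd = len(x), len(x[0])
    kh, kw = len(w), len(w[0])
    return [[sum(w[i][j] * x[r + i][c + j] for i in range(kh) for j in range(kw))
             for c in range(wd - kw + 1)] for r in range(h - kh + 1)]


def relu(m):
    """Elementwise ReLU, used here as the activation sigma."""
    return [[max(0.0, v) for v in row] for row in m]


def ddconv_dynamic(x, kernels, logits):
    """Mix the n kernels with softmax(logits), then convolve once and activate."""
    alphas = softmax(logits)
    return relu(conv2d_valid(x, mix_kernels(kernels, alphas)))
```

Because convolution is linear, mixing the kernels first gives the same pre-activation result as running n separate convolutions and mixing their outputs. It needs only one convolution, however, which keeps the cost close to that of a single vanilla convolution.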
Shifted Window Adaptive Complementary Attention Module (SW-ACAM)
The self-attention mechanism in the original Vision Transformer only models dependencies along the spatial dimension. It ignores the cross-dimensional dependencies between the spatial and channel dimensions. To overcome this limitation, a cross-dimensional self-attention module called SW-ACAM is proposed.
(S)W-ACAM combines the advantages of spatial and channel attention and can also capture long-range correlations between the spatial and channel dimensions. This allows the Transformer to learn richer features than conventional self-attention mechanisms.
As shown in Figure 3, (S)W-ACAM consists of four parallel branches. The top two branches form a conventional dual attention module, and the bottom two are cross-dimensional attention modules. The four branches complement each other and provide richer long-range dependency relationships.
Of these four branches, two capture channel correlation and spatial correlation, respectively. The remaining two capture the correlation between the channel dimension C and the spatial dimension H, and between the channel dimension C and the spatial dimension W.
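To make the four-branch structure concrete, here is a minimal, framework-free sketch that applies the same parameter-free self-attention along four different token/feature splits of a C×H×W feature map. This is a hypothetical illustration, not the (S)W-ACAM implementation. It omits the learned Q/K/V projections, the (shifted) window partitioning and the compact convolutional projection, and it fuses the branches with a plain average, whereas the paper's fusion may differ. Treating (C, H) pairs as tokens with W as the feature axis, and likewise (C, W) pairs with H as features, is one assumed way to realize the cross-dimensional branches.

```python
import math
from itertools import product

C_AX, H_AX, W_AX = 0, 1, 2  # axis ids of a feature map stored as x[c][h][w]


def self_attention(tokens):
    """Scaled dot-product self-attention with identity projections (Q = K = V)."""
    d = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[f] for w, v in zip(weights, tokens)) / z for f in range(d)])
    return out


def attend(x, token_axes, feat_axes):
    """Treat each index combination over token_axes as one token whose feature
    vector runs over feat_axes, apply self-attention, and scatter back."""
    dims = (len(x), len(x[0]), len(x[0][0]))
    tok_idx = list(product(*(range(dims[a]) for a in token_axes)))
    feat_idx = list(product(*(range(dims[a]) for a in feat_axes)))

    def pos(t, f):
        idx = [0, 0, 0]
        for a, v in zip(token_axes + feat_axes, t + f):
            idx[a] = v
        return idx

    tokens = [[x[i][j][k] for i, j, k in (pos(t, f) for f in feat_idx)] for t in tok_idx]
    y = [[[0.0] * dims[2] for _ in range(dims[1])] for _ in range(dims[0])]
    for t, row in zip(tok_idx, self_attention(tokens)):
        for f, v in zip(feat_idx, row):
            i, j, k = pos(t, f)
            y[i][j][k] = v
    return y


def acam_sketch(x):
    """Four complementary branches, fused by a plain average (assumption)."""
    branches = [
        attend(x, (H_AX, W_AX), (C_AX,)),  # spatial: H*W tokens, C features
        attend(x, (C_AX,), (H_AX, W_AX)),  # channel: C tokens, H*W features
        attend(x, (C_AX, H_AX), (W_AX,)),  # cross-dimensional C-H interaction
        attend(x, (C_AX, W_AX), (H_AX,)),  # cross-dimensional C-W interaction
    ]
    c_n, h_n, w_n = len(x), len(x[0]), len(x[0][0])
    return [[[sum(b[c][h][w] for b in branches) / 4 for w in range(w_n)]
             for h in range(h_n)] for c in range(c_n)]
```

Each branch only changes which axes are treated as "positions" and which as "features". That is why the four attention maps are complementary rather than redundant.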
To reduce the computational cost of the network, shifted-window and compact convolutional projection techniques are used. Window-based attention computes self-attention only within local windows, so each attention operation covers a small, fixed number of tokens rather than the entire feature map. Shifting the windows between successive blocks lets information flow across window boundaries. The compact convolutional projection compresses the channel dimension of the feature maps, further reducing the computational cost of the network.
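The saving from window attention can be quantified with the cost formulas reported in the Swin Transformer paper, on which Swin-Unet, and hence this Transformer branch, builds. They count the projections and attention products and ignore the softmax. The example sizes (a 56×56 map, C = 96, window size M = 7) are Swin-T's first-stage settings, chosen only for illustration. The effect of the compact convolutional projection is not modeled.

```python
def msa_flops(h: int, w: int, c: int) -> int:
    """Global multi-head self-attention: cost grows with (h*w)^2."""
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c


def wmsa_flops(h: int, w: int, c: int, m: int) -> int:
    """Attention inside non-overlapping m x m windows: linear in h*w."""
    return 4 * h * w * c**2 + 2 * m**2 * h * w * c


h = w = 56
c, m = 96, 7
glob, win = msa_flops(h, w, c), wmsa_flops(h, w, c, m)
print(glob, win, round(glob / win, 1))
```

For this feature-map size, restricting attention to 7×7 windows cuts the cost by more than an order of magnitude. The gap widens as resolution grows, because only the global term is quadratic in h·w.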
Experimental Results
Architecture variants
- TEC-Net-T: layer number {2, 2, 6, 2, 6, 2, 2}, H {3, 6, 12, 24, 12, 6, 3}, D = 96
- TEC-Net-B: layer number {2, 2, 18, 2, 18, 2, 2}, H {4, 8, 16, 32, 16, 8, 4}, D = 96
Here, D is the number of channels of the feature map entering the first layer of both the dynamically adaptive CNN branch and the cross-dimensional fusion Transformer branch. Layer number is the number of Transformer blocks in each stage, and H is the number of attention heads in each stage of the Transformer branch.
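The two variants can be written down as a small configuration object. The per-stage channel widths are not given in the source. The sketch therefore assumes Swin-Unet-style widths that double from D = 96 toward the middle stage and halve back. Under that assumption both variants have a constant per-head width (32 for T, 24 for B). This is a useful sanity check but not a figure reported by the authors. `TECNetVariant` is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TECNetVariant:
    name: str
    depths: tuple  # Transformer blocks per stage ("layer number")
    heads: tuple   # attention heads per stage ("H")
    dim: int = 96  # channels entering the first stage ("D")

    def widths(self):
        """ASSUMPTION: widths double down to the middle stage, then halve back up."""
        mid = len(self.depths) // 2
        return tuple(self.dim * 2 ** (mid - abs(i - mid)) for i in range(len(self.depths)))

    def head_dims(self):
        """Per-head channel width under the assumed widths."""
        dims = []
        for width, h in zip(self.widths(), self.heads):
            if width % h:
                raise ValueError(f"{self.name}: {width} channels not divisible by {h} heads")
            dims.append(width // h)
        return tuple(dims)


TINY = TECNetVariant("TEC-Net-T", (2, 2, 6, 2, 6, 2, 2), (3, 6, 12, 24, 12, 6, 3))
BASE = TECNetVariant("TEC-Net-B", (2, 2, 18, 2, 18, 2, 2), (4, 8, 16, 32, 16, 8, 4))

for v in (TINY, BASE):
    print(v.name, "blocks:", sum(v.depths), "widths:", v.widths(), "head dims:", v.head_dims())
```

The only difference between the variants in this table is depth in the third and fifth stages (6 vs 18 blocks) and the number of heads. D is the same for both.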
Implementation Details and Evaluation
Datasets: skin lesion segmentation dataset (ISIC2018), Liver Tumor Segmentation Challenge dataset (LiTS), Automated Cardiac Diagnosis Challenge dataset (ACDC)
Configuration: Adam optimizer with an initial learning rate of 0.001; mean squared error (MSE) loss and Dice loss as the loss functions
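Below is a minimal sketch of the two loss terms, computed on flattened predicted probabilities and binary ground-truth masks. The smoothing constant `eps`, the equal weighting of the two terms, and the function names are assumptions. The source only states that MSE and Dice losses are used.

```python
def mse_loss(pred, target):
    """Mean squared error over flattened per-pixel probabilities."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def dice_loss(pred, target, eps=1.0):
    """Soft Dice loss: 1 - (2*|P ∩ G| + eps) / (|P| + |G| + eps)."""
    inter = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * inter + eps) / (sum(pred) + sum(target) + eps)


def combined_loss(pred, target, w_mse=1.0, w_dice=1.0):
    """ASSUMPTION: weighted sum with equal weights; the source gives no weighting."""
    return w_mse * mse_loss(pred, target) + w_dice * dice_loss(pred, target)


print(combined_loss([0.9, 0.1, 0.8, 0.2], [1.0, 0.0, 1.0, 0.0]))
```

MSE penalizes every pixel equally. Dice normalizes by the size of the predicted and true regions, which keeps small structures such as lesions from being swamped by the background.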
The Quantitative Evaluation
Table 1 shows the quantitative results of the proposed TEC-Net and state-of-the-art (SOTA) CNN- and Transformer-based networks on the ISIC2018 dataset. TEC-Net-T requires only 11.58 M parameters and 4.53 GFLOPs, yet still achieves the second-best segmentation performance. TEC-Net-B, BAT, CvT, and CrossForm have similar numbers of parameters or computational costs. On ISIC2018, however, the Dice score of TEC-Net-B is 1.02%, 3.00%, and 3.79% higher than that of BAT, CvT, and CrossForm, respectively.
Table 2 shows the quantitative results of the proposed TEC-Net and competing networks on the LiTS-Liver dataset. TEC-Net-B and TEC-Net-T rank first and second, respectively, with the fewest model parameters and the lowest computational costs. Without pre-training, TEC-Net-B achieves a Dice score 1.20%, 1.03%, and 1.01% higher than that of Swin-Unet, TransUNet, and CvT with pre-training, respectively.
Table 3 shows the quantitative results of the proposed TEC-Net and competing networks on the ACDC dataset. TEC-Net retains a clear advantage on this MRI multi-organ segmentation dataset. Both TEC-Net-T and TEC-Net-B achieve state-of-the-art segmentation of the left ventricle (LV), right ventricle (RV), and left ventricular myocardium (MYO).
The Qualitative Evaluation
As depicted in Figure 4(a), the CNN branch captures more accurate local details of the segmented targets. Because it mainly captures local features, however, it is more vulnerable to noise interference. As shown in Figure 4(b), the Transformer branch captures more accurate position information of the segmented targets. Because it focuses on global features, however, it can miss their fine-grained details.
By integrating the CNN branch with the Transformer branch (Figure 4(c)), TEC-Net inherits the structural and generalization advantages of both CNNs and Transformers. It provides better local and global feature representations for medical images and demonstrates great potential for medical image segmentation.
References
[1] T. Lei, R. Sun, Y. Wan, Y. Xia, X. Du, and A. Nandi, "TEC-Net: Vision Transformer Embrace Convolutional Neural Networks for Medical Image Segmentation," IEEE Transactions on Medical Imaging, 7 June 2023.