Swin-Transformer

gary.TsAI(Taiwan A.I.)
9 min readAug 7, 2022

--

2020 年,將純粹的 Transformer 引入 CV 領域的論文引起了各路好手的廣泛關注,Vision-Transformer 的一系列研究工作大多數是對標準 Transformer 進行了增強。2021 年微軟研究人員發表了 Swin-Transformer,這可以說是原始 Vision-Transformer 後最令人興奮的研究之一。Swin-Transformer 不光是應用範圍廣,效果更是炸裂,看名字就知道是個基於 Transformer 模型,在視覺任務中取得了非常先進的 Performance,Swin-Transformer 目前被廣泛作為模型的骨幹(Backbone)。

首先,我們談談 Swin-T 跟 Vi-T 不同

如圖 1 所示,左圖是今天要談的 Swin-Transformer,右圖是 Vision-Transformer。大致上論文的貢獻為以下兩點:

  1. Vision-Transformer 對原始圖像下採樣 16 倍,後面的特徵圖也是維持這個下採樣率不變。Swin-Transformer 對原始圖像下採樣 4 倍,8 倍以及 16 倍,使用了類似卷積神經網絡(Convolutional Neural Network)中的層次化構建方法(Hierarchical Feature Maps)。一般 CNN 下採樣會使用池化層(Pooling layer),將輸入的圖片尺寸縮小,從而能減少模型參數數量,加快系統運作的效率,有利於減少模型 Over Fitting 問題。Swin-Transformer 提出 Patch Merging,用以達到跟池化層類似的下採樣操作。
  2. Vision-Transformer 中直接對整個特徵圖進行 Multi-Head Self-AttentionSwin-Transformer 提出 Windows Multi-Head Self-Attention(W-MSA),大幅減少計算量,將特徵圖劃分成了多個不相交的窗口(Window),並且 Multi-Head Self-Attention 只在每個窗口內進行。但問題就來了,單個窗口做 Multi-Head Self-Attention(MSA),彼此窗口之間不就無法傳遞訊息呢?所以 Swin-Transformer 提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA),藉由此方法讓訊息在彼此窗口之間進行傳遞。
圖 1 :左圖是 Swin-Transformer,右圖是 Vision-Transformer

Swin-Transformer 模型架構

這邊詳細介紹 Swin-Transformer 的模型架構 ,如圖 2 所示。左側為整體架構,右側為 Swin-Transformer Blocks 更加詳細的架構。

首先將一張圖片按給定大小分成一堆 Patches,將輸入圖片 [H, W, 3] 按照 4x4 大小的 Patch 進行劃分,劃分後會得到 [H/4, W/4, 48] 個 Patches。因為每個 Patches 裡有 4x4=16 個像素,然後每個像素有 R、G、B 三個值所以展平後是 16x3=48。接著通過 Embedding 層將每個 Patch 的 Channel 數據做線性映射,由 48 變成 C。這樣每個 Patche 數據 Shape 從 [H/4, W/4, 48] 變為 [H/4, W/4, C] 。Source Code 中 Patch Partition 和 Linear Embeding 就是直接通過一個卷積層實現的,和之前 Vision-Transformer 中講的 Embedding 層結構一模一樣。

接下來通過 4 個 Stage 構建不同大小的特徵圖,除了 Stage1 中先通過一個 Linear Embeding 層外,剩下三個 Stage 都是先通過一個 Patch Merging 層進行下採樣。然後重複堆疊 Swin-Transformer Block。這裡的 Block 有兩種結構,兩個結構是成對使用的,先使用一個 W-MSA 結構再使用一個 SW-MSA 結構。所以你會發現堆疊 Swin-Transformer Block 的次數都是偶數,所以才有了 Swin 這個名字(?)。最後對於分類網絡,後面還會接上一個 Layer Norm 層、全局池化層以及全連接層得到最終輸出。圖中沒有畫,但 Source Code 中是這樣做的。

圖 2:左側為整體架構,右側為 Swin-Transformer Blocks 詳細的架構

Patch Merging

Swin-Transformer 提出 Patch Merging,用以達到跟池化層類似的下採樣操作。如圖 3 所示,假設輸入 Patch Merging 的是一個 [4, 4, 1] 的單通道特徵圖,Patch Merging 會將每個 2x2 的相鄰像素劃分為一個 Patch,然後將每個 Patch 中相同位置(同一顏色)像素給拼在一起就得到了 4 個特徵圖。接著將這 4 個特徵圖在深度方向進行 Concat 拼接,接著通過一個 LayerNorm 層。最後通過一個全連接層在特徵圖的深度方向做線性變化,將特徵圖的 Channel 由 C 變成 Cx2。這個例子可以看出,通過 Patch Merging 層後,特徵圖的 H 和 W 會減半,C 會翻倍。

圖 3:Patch Merging 的操作方式

Windows Multi-Head Self-Attention(W-MSA)

Swin-Transformer 提出 Windows Multi-Head Self-Attention(W-MSA),大幅減少計算量,將特徵圖劃分成了多個不相交的窗口(Windows),單獨對每個窗口內部進行 MSA。並且 MSA 只在每個窗口內進行。如圖 4 所示,左圖紅框為 Swin-Transformer 單獨對每個窗口(16 個 Patch)內部進行 W-MSA,右圖紅框為 Vision-Transformers 對整張特徵圖進行 MSA。

圖 4:左圖紅框為 Swin-Transformer 單獨對每個窗口內部進行 W-MSA,右圖紅框為 Vision-Transformer 對整張特徵圖進行 MSA

Shifted Windows Multi-Head Self-Attention(SW-MSA)

Swin-Transformer 提出 Shifted Windows Multi-Head Self-Attention(SW-MSA),解決因為單個窗口做 MSA,彼此窗口之間無法傳遞訊息的問題,進行偏移的 W-MSA。問題就來了,原本的窗口是 4 個,偏移之後變成了 9 個,增加了計算複雜度。如圖 5 所示,左圖(第 Layer 1 層)使用的是 W-MSA,右圖(第 Layer 1+1 層)使用的就是 SW-MSA,因為這兩個結構是成對使用的。

圖 5:左側為 W-MSA,右圖為 SW-MAS

Efficient batch computation for SW-MSA

如圖 6 所示,為了提高 SW-MSA 的計算效率,作者把上述(第 Layer 1+1)層改為 cyclic shiif,將左上角淺黃色 A 移至右下角的 A,左邊淺藍色 B 移至右邊的 B,上方淺綠色 C 移至下方的 C。進而讓窗口維持 4 個。

那問題又來了,同一個窗口可能出現 Patches 亂兜,Patches 是從別的地方給剪貼過來的。就好比原本的左下角的窗口,原本的天空被放到跟地面一起去計算 MSA。

圖 6:Efficient batch computation for SW-MSA

Masked MSA

Swin-Transformer 提出了 Masked MSA 解決同一個窗口可能出現 Patches 亂兜問題。如圖 7 所示,為了方便大家理解,我自己又重新畫了一次移動窗口,並對每個窗口加上了一個標籤。圖左是原先移動後的窗口;圖中是將標籤 1、2、3 從上搬到下;圖右是將標籤 4、7、1 從左搬到右,這樣窗口數量就會從原本圖左的 9 個降低為圖右的 4 個,就能對單個藍色窗口進行 MSA。

最後就來介紹一下作者精細設計的四個 mask(我將之翻譯為遮罩) 吧!

圖 7:加入標籤的 SW-MSA

如圖 8 所示,(a) 將左下角窗口沿著左向右、上向下(黃線),依序把 Patches 排列出 (b) 來 ,因此標籤為 8 的 Patches 共有 28 個,標籤為 2 的 Patches 共有 21 個。(c) 將排列後的矩陣與其反置矩陣進行 Dot-Product,產生 (d) 的 Attention-Matrix,左上角生成多個 α88, 右下角生成的多個 α22,這兩個部分是需要保留的,視為相同標籤之間 MSA 運算。右上角生成多個 α82,左下角生成多個 α28 ,視為不同個標籤之間 MSA 運算,因此將它都減去 100(α 的值都很小,一般都是零點幾的數字,將其減去 100 後通過 SoftMax 得到的權重就會等於 0 了)

圖 8: Masked MSA 設計

最後給大家看一下作者在 GitHub 的回覆,他將 Mask 給可視化出來,有興趣的小夥伴可以去跑看看他提供的程式碼。如圖 9 所示,左圖為做完 reverse cyclic shift 的結果,右圖為他替這四種窗口設計的 Mask。

圖 9:Attention Mask

--

--