Review a series of Semi-Supervised Learning algorithms: MixMatch, ReMixMatch and FixMatch

Guan
工人智慧
Published in
21 min readJul 7, 2020

Introduction

近年來 Semi-Supervised Leaning 大行其道,主要原因為面對越來越複雜的工作,人工標籤數據的成本越來越高,且對於 Supervised Learning 來說,在開放式的場景下蒐集並人工標籤得來的數據集常帶有某種偏誤,使得模型不夠 Robust or Generalized ,更多的 label data 讓訓練的進展有限,甚至傷害模型表現,模型不能夠完全相信 label data 。

而通過存在於數據本身的(取得更低成本的) generalized information ,讓模型自行理解數據,就是 Semi-Supervised Learning 想達到的目標了(Self-Supervised Learning 也是)。

本文三篇皆是關注在透過做出穩定的 pseudo-labeling 以及 consistency regularization 達成上述目標。前者是針對 unlabeled data 做出像是 Supervised Learning 而便於訓練(但 labeling 這件事情可能本身就有問題),後者則是集合原有的 label data 和大量被賦予 pseudo label 的 unlabeled data ,透過有品質的 data augmentation 訓練出媲美使用更大量 label data 的 Supervised Learning 的結果。

更直觀的說,unlabeled data 就算並沒有經過人類賦予新的資訊,它本身帶有的資訊也足夠提昇模型的表現,差別只是在於哪一些方法能夠有效率且有效果的提煉出正確的資訊而已。

MixMatch

Introduction

這篇是接續 mixup: Beyond Empirical Risk Minimization 所延伸出來的方法,mixup 主要探討在模型最佳化的過程中,很容易將不同類別的 convex 變得非常不平滑,在不同類別的交界有非常大的 gradient ,也就是 overfitting 所能觀察到的現象,模型不夠泛化,也更容易受到 adversarial examples 的攻擊。對於了解深度學習的最佳化非常值得一讀。

如先前所提,Semi-Supervised Learning 的本質就是透過 unlabeled data 有效率的在訓練過程中提供正確的資訊,而在 MixMatch 當中,它將希望寄託於 mixup ,主要的原因來自於 mixup 主打能夠使模型更加泛化,是一種 regularization ,為了對抗 label 而存在的 augmentation,換句話說,它能夠 取得 label 以外的有效資訊。

包含 Mixup ,本篇使用了其他 Semi-Supervised Learning 中經典的其他方法:

  1. Data Augmentation for Consistency Regularization
  2. Labeling Guessing
  3. Prediction Sharpening
  4. Mixup

Pseudo code 如下:

MuxMatch Algorithm

對於 unlabeled data 的處理流程:

Diagram of MixMatch Algorithm

Data Augmentation and Labeling Guessing

在 MixMatch 中僅有使用兩種基本的 Data Augmentation , horizontal flip and crops ,對於 label data 是同時使用,而 unlabeled data ,則是分開使用,也就是對於一張圖片,分別透過 flip 和 crop 各產生一張圖片,在 pseudo code 中,就是 K = 2 的意思。

在分別為 label data 和 unlabeled data 做完 augmentation 之後,接著要對 unlabeled data 做出 labeling guessing ,我認為 labeling 這件事情存在很多限制,但是目前最直觀簡便計算 loss 的方式。這邊 pseudo labeling 是將剛剛產生的兩張 filp 和 crop 的圖片輸入模型,得出的 possibility 再做算術平均得出另一個 possibility vector。

Sharpening

由於我們在 unlabeled data 上得到的是一個 possibility vector ,仍需要被決定預測的是哪個標籤,引入一個 sharpening function

Sharpening Function

至此,我們就可以再使用 argmax 得出所猜測的標籤是什麼了。

Mixup

在分別得出了 label 和 unlabeled data 的標籤後,在輸入 loss function 前還必須要再做 mixup ,mixup 之後輸出的標籤才會是送入 loss function 的標籤。

mixup 簡而言之,是將圖片和標籤,與另一張圖片及標籤做線性加權:

十分直覺的作法,一般 data augmentation 僅對圖片進行擾動,但 mixup 認為與其加入 noisy ,不如直接加入其他圖片資訊,由於該圖片的標籤亦包含資訊,同時也使用該標籤對原先標籤做擾動。

在此,我們取原先 label data 和 unlabeled data 組合成一個新的資料集,再將此資料集隨機 shuffle 後產生第二個資料集,對兩個資料集應用 mixup,最終得出的資料集,才被送出計算 loss。

Loss Function

Loss Function of MixMatch

在上述 MixMatch 的演算法中得出的最終 label data 和 unlabeled data 的資料集後, label data 仍然是使用 cross-entropy loss ,而 unlabeled data 特別使用 L2 loss,主要有兩個原因:

  1. Bounded
  2. Less sensitive to completely incorrect predictions

Ablation Study

MixMatch 接下來做了不少 ablation study ,分別證明了

Ablation Study
  1. Distribution averaging is better (K > 1 is better than K = 1)
  2. Prediction sharpening is better ( T <1 is better than T = 1)
  3. Applying Exponential moving average on model parameter during training
  4. Mixup is helpful
  5. Applying unlabeled data into mixup is helpful
  6. Applying label data into mixup is helpful
  7. Mixing label and unlabeled data is helpful

這邊的設計比較都來自於先前 Semi-Supervised Learning 的方法如 Mean Teacher 和 Pi Model,證明在 pseudo-labeling 和 consistency regularization 透過 MixMatch 的都有得到更好的效果。

Conclusion

事實上細節還不少,但之後的 ReMixMatch 能夠達到更好的效果,FixMatch 則化繁為簡,我認為將 MixMatch 作為 prerequisite 即可,不須深入。

ReMixMatch

Introduction

在 MixMatch 的基礎上改進了兩個部份:

  1. Distribution Alignment
  2. Augmentation Anchoring

另外,雖然此時 AutoAugment 已經問世,ReMixMatch 仍提出另一 CTAugement 作為 Data Augmentation 的方法。

此外,不同於 MixMatch 使用算術平均數計算出單一個 guessing label ,ReMixMatch 採用了更複雜的 Loss Function ,使不同類的 unlabeled data 在空間中能夠被分隔的更遠

… argues that unlabeled data should be used to ensure that classes are well-separated. This can be achieved by encouraging the model’s output distribution to have low entropy (i.e., to make “high-confidence” predictions) on unlabeled data. For example, one can explicitly add a loss term to minimize the entropy of the model’s predicted class distribution on unlabeled data

以下我們一項一項介紹:

Distribution Alignment

Distribution Alignment

從 label data 中的各分類比例,將之除以 guessing label 的比例後,所得出的相對比例,將這個相對比例 element-wise 乘上之後所得的 guessing possibility vector,換句話說,就是使用 groundtruth 的比例去加權猜測出來的類別。

舉例來說,如果在 groundtruth 中 A 類佔 90% ,而目前所猜的 A類佔 50%,則該次猜測 A 類的 possibility 都會被乘以 0.9 / 0.5 = 1.8。

直觀的看,當 groundtruth partition >>> guessing partition 時,也就是 guessing label 的比例遠低於 groundtruth 時,就會乘上一個非常大的項次增加猜測該類別的信心,反之則是大幅減少猜測該類別,這將使得 guessing label distribution 往 groundtruth label distribution 靠攏。

Augmentation Anchoring

Augmentation Anchoring

不同於 MixMatch 對 unlabeled data 做 K 次 augmentation 再取算術平均的方式得到 guessing label,再與使用 K 張經過 augmentation 的圖片做 consistency regularization。ReMixMatch 對同一張 unlabeled data 使用了 weak augmentation 和 strong augmentation ,前者直接指定為 guessing label,後者則再與前者做 consistency regularization 。

改為這樣處理的主要原因是,各種 augmentation 的強度和影響是不一致的,使用不同的 augmentation 後再取標籤的算術平均是不合理的,在 MixMatch 中僅僅使用了 flip 和 crop 也可能出自於此,在其 ablation study 中,使用除了 flip 和 crop 之外的 augmentation 會導致模型表現受損,而 ReMixMatch 這樣的方法允許我們帶入像是 AutoAugment 這樣豐富的 augmentation。

CTAugment

AutoAugment 是透過擁有大量標籤的 Supervised Learning 而搜尋到的 augmentation policies ,而這與 Semi-Supervised Learning 的精神相違背,故在此我們不能使用 AutoAugment。

而 CTAugment 選擇一種方法是,不以每一次模型的準確率給出反饋,才選擇適合的 augment policies ,而是在訓練過程中按照 prediction confidence 確認 augment 會不會過於強烈。

實際的演算法是,首先像是 AutoAugment 一樣,維護一個矩陣為 num_augment * num_magnitude,矩陣中的值為每種 augment 在該強度被選取到的機率權重 m_i,並初始化為皆為 m_i = 1 的矩陣。

在訓練的時候首先 uniformly 選取兩個 augment ,對每一個 augment 按照 Categorical(Normalize(m_i)) 後的機率值選取 magnitude,其中

Normalization

而 Categorical(x′) = if x > 0.8, x′ = x, else x′ = 0

將圖片經過這兩種 augment 處理後計算 ω

並按照

更新該 augment權重 m_i,如此反覆更新權重,最後就會得到一群帶有適當機率權重 m_i 的 augmentation。也就是一個矩陣, Column 為某一種 augmentation ,Row 為該 augmentation 的強度等級,矩陣內的值為機率權重 m_i(非機率值,經過 Categorical(Normalize(m_i)) 才是被選取的機率)。

Loss Function

Rotation Loss

和 MixMatch 不同, ReMixMatch 的 Loss Function 中加入了 Rotation Loss:輸入一張經旋轉一定角度(0, 90, 180, 270)的圖片,要求模型輸出圖片的旋轉角度,注意這是一個四個類別的分類問題。

FixMatch

Introduction

FixMatch 與前兩者最大的差別就是移除了 Mixup ,因為事實上僅需要使用 Augment Anchoring 加上 confidence threshold 就能達到很好的效果。

除此之外,Temperature Sharpening 也被 confidence threshold 取代了,故也不需要定義參數 T。

Diagram of FixMatch

整個流程十分簡單,label 和 unlabeled data 皆使用 cross-entropy loss 訓練,而 unlabeled data 的標籤產生不過就是 CTAugment + Augment Anchoring + confidence threshold 而已。

反而花了更多章節討論訓練細節,如 learning rate strategy 和 Optimizer,原因是作者認為 Semi-Supervised Learning 的各種方法並沒有清楚的說明自己使用的架構、training details等等,比起方法論本身,可能對於結果的影響會更大,所以在 ablation study 加入很大篇幅探討這些細節。

下面我找出幾個重點章節詳細討論:

Our Algorithm: FixMatch

Loss of Label Data
Loss of Unlabeled Data

Label data 的訓練很單純的使用 cross-entropy loss ,而 unlabeled data 比起先前的兩個方法也簡單很多:

  1. 使用 weakly-augmented 處理輸入圖片
  2. 將處理後圖片輸入模型取得 possibility vector 之後,直接取 argmax 得到 pseudo label
  3. 使用 strongly-augmented 處理輸入圖片
  4. 將處理後圖片輸入模型取得 possibility vector 之後與 pseudo label 取得 loss function
  5. 僅將 weakly-augmented possibility vector > confidence threshold 部份的 loss 留下

在之後的 ablation study 中會證明,單單僅是靠 confidence threshold 過濾 weakly-augmented 的輸出並 mask 到 loss ,就可以得到非常有品質的 consistency regularization,不需要使用 mixup 。

Barely Supervised Learning

為了測試 FixMatch 的極限,設定一個 CIFAR-10 實驗,label data 僅有各類一張,共 10 張圖片作為 label training data,其餘皆為 unlabeled data。

隨機抽取四次並在每一個資料集上各做四次訓練的結果,也就是總共 16 次訓練,最低達到 48.58% ,最高 85.32%,中位數則是 64.28%。

意外的發現,實驗結果顯示,訓練在同個資料集的四次訓練結果,相對於訓練在不同資料集的訓練結果,前者的差異遠小於後者的差異,也就是 inter-dataset variance >>> intra-dataset variance。

也就是說,在少量的 label data 的訓練中,其品質對於 Semi-Supervised Learning 所帶來的影響也許比我們想像中的要大。

為了驗證這件事情,使用了 Distribution density, tails, and outliers in machine learning: Metrics and applications 的方法,按照 representative 將圖片分為 8 個分類,由 most representative to least representative 。這裡的 least representative 按照該論文,可以解釋為在資料集中的 outlier。

結果是,訓練在第一個分類,也就是 CIFAR-10 中處在整個資料維度中心的圖片,達到了中位數 78% 的正確率,而第八個分類,也就是 CIFAR-10 中的 outlier ,僅有 10% 的正確率。

8 buckets ordering by representative

Ablation Study

這個章節絕對是本文重中之重了,FixMatch 事實上和 Semi-Supervised Learning 許多其他方法很相似,如 Mean-Teacher 和 Pi-Model 等,且相較系列作前兩篇,竟使用到更少的參數及更簡單直覺的演算法,值得在 ablation study 中討論究竟是哪些關鍵導致模型顯著的進步。

Since FixMatch comprises a simple combination of two existing techniques, we perform an extensive ablation study to better understand why it is able to obtain state-of-the-art results.

Ablation Study
  1. Sharpening and Thresholding
    Fig. 3(c) 表明了 Temperature Sharpening 在設定合適的 confidence threshold ( ≥ 0.8 etc.) 之下,並不會影響表現,且 Fig. 3(b) 的結果顯示 confidence threshold 在大於 0.8 時差異並不顯著,可以簡單的設定參數,也比 Temperature 直觀許多。
  2. Augmentation Strategy
    除了談到 CTAugment + Cutout 會有最好效果外,也將 Augment Anchoring 中的 weakly-augmented 和 strongly-augmented 對調,實驗結果表明雖然仍能夠達到好的準確度,但每次訓練的準確度卻不穩定,最準確和最不準確的訓練在準確度上的差異高達 12 % 。
  3. Ratio of Unlabeled Data
    Fig. 3(a) 證明加入越大量的 unlabeled data ,有助於 FixMatch 取得更好的表現。此外,如果線性的隨著 batch size 調整 learning rate 大小,在 FixMatch 上會使訓練更有效率。
  4. Optimizer and Learning Rate Schedule
    Table 5 討論了 optimizer 和 learning rate 以及 momentum 等參數對於模型的表現,毫不意外的 SGD + Momentum 的組合依然可以取得最好的表現,而 Nesterov Momentum 並沒有顯著優勢。
    此外,近年來更喜歡使用 cosine learning rate policy ,在 FixMatch 中,線性下降加上適當的 Weight Decay 也可以達到相似的結果,但似乎也證實了 cosine learning rate 確實是較好的選擇。
    其他更多的組合可以參看論文的 Table 9 。
  5. Weight Decay
    在上圖 Fig. 3(d)中可以看到,最佳與最差的設定可以使表現相差 10 % 以上,但經驗上一般都會使用 0.0001 ,本文也是使用此值。

Conclusion

Semi-Supervised 為人所詬病的一點是僅在 CIFAR 及 SVHN 等小數據集上驗證,對於 COCO 甚至 ImageNet 這樣龐大的標準數據集沒有著墨,更遑論了解應用在真實世界的資料上會有什麼效果了,MixMatch 和 ReMixMatch 都是如此,直到 FixMatch 才直面了 ImageNet。

FixMatch 證實了使用少量高品質的 label data ,以及大量的 unlabeled data 就能夠有非常好的效果。和同為 Semi-Supervised Learning 的 Noisy Student 不同,Noisy Student 使用了 EfficientNet ,並 iterative training ,雖然擊敗 Supervised Learning 成為新的 ImageNet the State-of-art ,但可以想像這樣的訓練成本是非常龐大的。我個人心中,Noisy Student 不但難以復現(世界上有多少人可以 iterative training on EfficientNet-L2?)而且違背了 Semi-Supervised Learning 的精神(標記成本再高,應該也不會高過訓練的成本吧…),FixMatch 會更適合作為實務上的應用。

另外,如果使用 FixMatch 做 iterative training 會如何?也許是另一個經濟實惠的選擇。

下篇也許補上 Noisy Student ,或是另開新坑 Self-Supervised Learning 。

--

--