論文閱讀 NeurIPS 2020 — FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

該論文巧妙地結合了 Pseudo-labeling 和 Consistency regularization 等 Semi-supervised learning 技巧,提出新穎的演算法稱為 FixMatch。而整合的過程有對演算法進行簡化,不僅在多個資料集上達到 SOTA 表現,在資料量極少的情況下,準確度仍相當驚人。

Ken Huang
人工智慧,倒底有多智慧?
17 min readJun 19, 2021

--

論文連結 :
《 FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence 》

作者與出處:

Sohn, K., Berthelot, D., Carlini, N., Zhang, Z., Zhang, H., Raffel, C., Cubuk, E., Kurakin, A., & Li, C.L. (2020). FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. In Advances in Neural Information Processing Systems (pp. 596–608). Curran Associates, Inc..

GitHub Link

Introduction

深度神經網路已成為當今電腦視覺實際上應用的模型,這類技術之所以成功的部分原因在於它的可擴展性,通常只要在越大的資料集上進行訓練,就會得到越好的效果。

但這過程在目前的做法中,大多是以 Supervised learning 的做法執行,這需要仰賴大量的人力對資料進行標註,成本非常可觀,特別是應用領域需要專業知識判斷時,這個現象就更明顯 ( Ex: 醫學相關應用可能就需要醫生來進行標註 )。

Semi-supervised learning ( SSL )

而 Semi-supervised learning 是另一種善用 Unlabeled data 緩解標注資料需求的做法,由於 Unlabeled data 較容易取得,使該作法可以用較低的成本提升最終的成效。

而常見的作法有以下 2 種 :

  1. Pseudo-labeling ( Self-training ):讓模型對 Unlabeled data 做預測,再以這些預測結果作為標註 ( Pseudo-label ) 進行訓練。
  2. Consistency regularization:與 Pseudo-labeling 的流程類似,但會對輸入資訊隨機地做一些擾動。

近期在這 2 個方向上的研究都有複雜化的趨勢,而該論文將上述 2 種技巧混合使用,並提出一個更簡單、準確的演算法,稱為 FixMatch

Sohn, K., Berthelot, D., Carlini, N., Zhang, Z., Zhang, H., Raffel, C., Cubuk, E., Kurakin, A., & Li, C.L. (2020). FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. In Advances in Neural Information Processing Systems (pp. 596–608). Curran Associates, Inc. Figure 1.

該算法會對輸入影像同時做 2 種 Data augmentation:

  1. Weak:做法上是用翻轉、位移等操作執行,其結果用來生成 Pseudo-label ,並作為另一種 Augmentation 的 Target
  2. Strong:受到先前研究啟發,這邊使用的是 Cutout、CTAugment、RandAugment 這種比較複雜的 Augmentation 方法,通常會讓影像失真的程度較高

在生成 Pseudo-label 的過程也參照先前研究的作法,信心程度的數值必須高過一定的程度,才會當作後續的訓練資料。

根據論文實驗來看,FixMatch 超越了過去的 SOTA ,甚至在每個類別只有 4 筆資料的情況,還能有 88.61% 的準確度 ( 在 CIFAR-10 資料集上 ),詳細的實驗數據會在 Experiments 章節呈現。

FixMatch

在描述 FixMatch 之前,先來 Review 一下 Pseudo-labeling 和 Consistency regularization。

Background

Consistency regularization 的核心思想是:一個模型面對一張影像的不同 Augmentation 結果要有相似的預測,這使得相關研究在對 Unlabeled data 定義 Loss 的時候比較特別:

α 和 pm 是 stochastic function,會需要設定兩個不同的值。µB 指的是一個 Batch 的 Unlabeled samples。

有些延伸作法會把 α 替換成 Adversarial transformation;或是同時使用過去幾個版本的模型做預測後取平均,再放入 pm;也有把平方的 L2 Loss 換成 Cross-entropy 的作法;甚至使用強度更大的 Augementation 。

而 Pseudo-label 的核心思想是:讓模型自己對 Unlabeled data 生成 Label ,再拿來訓練自己。Loss function 定義如下:

µB 涵義同上。qb = pm( y | ub ),是指模型對 Unlabeled data 做推論的機率分布。τ 是個 Threshold ,只有當某個類別的信心指數夠高時,才會將其作為訓練資料。最右邊的 H( ) 則是指說,把兩組機率拿來計算 Cross-entropy。

整體來看,如果模型對這筆 Unlabeled data 夠有信心 ( max(qb) ≥ τ ),那就會透過最小化 Entropy 的方式來進行學習 ( H( qb , qb) )。

FixMatch

在 FixMatch 演算法中,包含了 2 個計算 Cross-entropy 的地方,其一是對 Labeled data 做 Weak augmentation 之後,執行 Supervised learning 的部分, Loss 定義為 ℓs:

這邊就是個標準的 Cross-entropy

另一個比較特殊的部分是對 Unlabeled data 的處理,Loss 定義為 ℓu:

如同前一章節的流程圖,這邊會先檢核模型對 Weak augmentation 的預測結果中,機率最大值是否超過閥值 τ,在超過的情況下會被轉換為 Pseudo-label,再與 Strong augmentation 的推論結果計算 Cross-entropy 作為 Loss。

而這 2 種 Loss 結合的時候還會透過 λu 控制 Unlabeled data 的影響程度:

而在近期一系列關於 SSL 的研究中有個趨勢,這個 λu 會根據訓練時間調整影響的程度 ( 隨著訓練時間越長,λu 會越大 )。但這個設定在 FixMatch 卻不起作用,作者們推論原因是訓練初期的 max( qb ) 大多比 τ 小所導致的。

Augmentation in FixMatch

前面簡介過 FixMatch 使用的 Data augmentation 有分 Weak 和 Strong。

Weak 部分是常見的翻轉、偏移等操作,參數上的設定是以 50% 的機率對影像做水平翻轉,再透過 12.5% 的幅度作為垂直水平偏移量的上限。

Strong 的部分就複雜許多,該論文使用了 2 種 AutoAugment 的變種方法:RandAugment 和 CTAugment。

RandAugment 不同於 AutoAugment 使用 Reinforcement learning 來學習控制 Augmentation 的 Policy,且從命名來看也不難猜到,是用隨機的方式調整相關參數。而 CTAugment 與 RandAugment 的差異在於,RandAugment 的隨機性會介在一個事先定義好的範圍內,CTAugment 則不受限。

Additional important factors

除了 FixMatch 算法本身相關的參數外,其實還有些像是 Regularization 的因素會影響最後的成效,就像深度神經網路要訓練時,也會有一些架構、優化器、訓練策略等環節可被調整,這在過去新的 SSL 算法提出時,常被忽略。

該論文則盡可能地量化這些因素的影響,後續呈現相關數據時,再一併介紹該論文做的設定,以及實驗過程的發現。

Extensions of FixMatch

由於 FixMatch 本身被設計得很簡潔,所以可以沿用過去其他 SSL 算法的一些技巧,像是 ReMixMatch 的 Augmentation Anchorin、Distribution Alignmen。

或是有些像是 MixUp、Adversarial perturbation 的方法,也可以拿來取代 Strong Augmentation 的做法。

相關實驗數據只有在原版論文的附錄有附上,有興趣可以參考這裡

Related work

這種 Self-training 的想法,其實已被提出數十年了,且被應於多種不同的領域,而 Pseudo-labeling 在近年常被當作算法的一部分,用來強化整體表現。至於 Consistency regularization 則有仰賴 Strong augmentation 的趨勢,就近期一系列的研究來看,成效確實不錯,也有人把相關技巧與 Distillation 做整合。

而與 FixMatch 最相關的作法是 Unsupervised Data Augmentation ( UDA )ReMixMatch,這兩個作法都有先用 Weak augmentation 取得 Label ,再強制 Strong augmentation 的 Representation 要有一致性。不過它們都沒有用 Pseudo-labeling ,取而代之的是一種 Sharpening 的作法,並鼓勵模型做出 High-confidence 的預測。而 FixMatch 透過閥值過濾 Pseudo-label 的作法,基本上與 Sharpening 有類似的效果。

另外,ReMixMatch 有刻意在計算 Loss 的部分降低 Unlabeled data 的影響程度,但 FixMatch 是以閥值過濾的作法達到作用。

基於前述的相似性,其實可以把 FixMatch 想成簡化過的 UDA 和 ReMixMatch,且在 2 者結合時,移除了多個 Component 。

該論文也把其他有相似性的 SSL 方法整理在下表:

並將它們做 Augmentation 和對 Label 處理的方法進行了比較。

Experiments

該論文在做實驗時,相關設定皆比照過去 SSL 研究的設定,並在 CIFAR10/100 、SVHN 、STL-10 和 ImageNet 這些常見的資料集上,與其他 SOTA 方法做比較。

在模型架構上有使用 3 種不同的架構:

  1. Wide ResNet-28–2 ( For CIFAR10 、SVHN )
  2. WRN-28–8 ( For CIFAR100 )
  3. WRN-37–2 ( For STL-10 )

另外,為了凸顯 FixMatch 的性能,作者們還額外做了資料量極少的實驗 ( 只有 40 筆資料有 Label ),並與 UDA 和 ReMixMatch 做比較:

基本上除了在 CIFAR-100 之外,FixMatch 都達到了 SOTA 的水準。

為了理解 FixMatch 在 CIFAR-100 表現較為遜色的原因,作者們額外把 ReMixMatch 的一些 Component 放到 FixMatch 做實驗,並發現影響最多的是 Distribution Alignment ( DA )。

這個技巧會鼓勵模型在相同類別下有更接近的預測結果,單獨將這個 DA 加入在 FixMatch 的話,可以做到 40.14 % 的Error rate,比 ReMixMatch 的 44.28 % 更優。

Barely Supervised Learning

為了測試 FixMatch 的極限,該論文還做了一個很極端的實驗,在 CIFAR-10 資料集的每個類別只使用 『 1 』 張有 Label 的資料。

一開始用隨機的方始選,效果不太好,平均準確度只有 64.28 %,作者們推論是這些隨機選的樣本品質太差。為了驗證這想法,作者們參考了一篇分析資料分佈的論文,並在每個類別選取最具代表性的 1 筆資料:

Sohn, K., Berthelot, D., Carlini, N., Zhang, Z., Zhang, H., Raffel, C., Cubuk, E., Kurakin, A., & Li, C.L. (2020). FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. In Advances in Neural Information Processing Systems (pp. 596–608). Curran Associates, Inc. Figure 2.

再把 FixMatch 應用在這 10 張影像進行訓練,結果平均準確度有提升到 78 %

Ablation Study

為了理解 FixMatch 算法為何能取得較好的結果,該論文也對算法上可以調整的不同地方做了 Ablation study,這部分只在 CIFAR-10 上執行,並使用 250 張有 Label 的影像。

Sharpening and Thresholding

與 Pseudo-labeling 較相關的參數是對模型預測出的機率分佈處理的方法,這環節有 2 種做法:

  1. 透過 Threshold ( τ ) 篩出 One-hot label
  2. 用 Temperature ( T ) 控制 Sharpening 程度

作法 1. 可對應到下圖的 (a),而 (b) 則是對 2 者結合的情況做實驗:

Sohn, K., Berthelot, D., Carlini, N., Zhang, Z., Zhang, H., Raffel, C., Cubuk, E., Kurakin, A., & Li, C.L. (2020). FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. In Advances in Neural Information Processing Systems (pp. 596–608). Curran Associates, Inc. Figure 3.

Threshold 數值控制的是 Pseudo-label 數量與質量之間的 Trade-off,實驗結果是設定為 0.95 時有最好的效果,這表示對模型來說,要做出更準的預測,資料的質量比數量重要。

而對照 Temperature 的設定來看,當 Threshold 數值高的時候,Temperature 的設定相對沒太大的影響。

Augmentation Strategy

另一個對 FixMatch 算法影響很大的環節是 Augmentation policy,在該論文的實驗中,有使用 RandAugment 和 CTAugment,而這 2 種方法在做完 Strong augmentation 後都有做 Cutout,這邊試著對這環節做實驗:

Sohn, K., Berthelot, D., Carlini, N., Zhang, Z., Zhang, H., Raffel, C., Cubuk, E., Kurakin, A., & Li, C.L. (2020). FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. In Advances in Neural Information Processing Systems (pp. 596–608). Curran Associates, Inc. Table 3.

從上表可看出,Cutout 和 CTAugment 必須同時使用才會取得最好的結果。

而在 Weak 和 Strong augmentation 的作法上,該論文也嘗試過幾種不同的作法:

  1. 將 Weak augmentation 替換成 Strong augmentation:
    這使得模型在訓練初期就發散,無法收斂
  2. 將 Weak augmentation 拿掉:
    這使得模型出現 Overfitting 的情況
  3. 當訓練準確度達到 45 % 時,將 Strong augmentation 換成 Weak augmentation:
    訓練過程變得不穩定,準確度崩塌至 12 %

這些嘗試顯示出目前的組合是較好的方法,且吻合一般在 Supervised learning 觀察到的現象。

Conclusion

SSL 近年有非常迅速的發展,但這些進步來自於演算法的複雜化,使得 Hyper-parameter 的調整變得更困難。

而該論文的 FixMatch 用較簡單的設計在多個資料集上達到 SOTA 的表現,甚至在極少資料量的情況有著相當高的準確度,這巧妙地搭建了 Few-shot learning 和 Semi-supervised learning 之間的橋樑。

就整體而言,作者們相信這類簡單又有效的 SSL 演算法,可以在人工標註非常昂貴又難取得的窘境中,讓 Machine learning 的技術更有機會被部屬到真實的應用中。

--

--

Ken Huang
人工智慧,倒底有多智慧?

在網路上自學的過程中,體會到開放式資源的美好,希望藉由撰寫文章記錄研究所的學習過程,同時作為回饋網路世界的一種方式。Email : kenhuang2019iii@gmail.com ,如果有任何問題都歡迎與我聯繫。