論文閱讀 NIPS 2021 — FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling

該論文將 Curriculum learning 結合 Pseudo Labeling 來改善現有 Semi-supervised learning 作法 ( FixMatch ),解決過去在訓練過程無法考量各類別學習難易度的缺點,有效提昇準確度和收斂速度。

Ken Huang
人工智慧,倒底有多智慧?
20 min readFeb 24, 2022

--

論文連結:
《 FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling 》

作者:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021).

GitHub Link

Introduction

由於 Semi-supervised learning 相關技術可善用大量的 Unlabeled data 來提升模型的學習成效,是近年相當引人注目的研究領域。

其中有 2 種主要的技巧是「 Consistency regularization 」「 Pseudo labeling 」,都能在訓練過程有效善用 Unlabeled data 的資訊 。而在 NIPS 2020 提出的 FixMatch 更是結合上述兩種技巧和 Weak/Strong data augmentation,做出了超越過往的成效。

然而像是 FixMatch 或其他 Pseudo-Labeling 、UDA 的演算法,通常都仰賴一個固定的 Thresholding 來計算 Unsupervised loss,並在訓練過程只會採用 Confidence 超過 Threshold 的 Unlabeled data,這個策略可以確保模型學到 的是 High-quality 的資訊,但同時也導致訓練過程只能善用少部分的 Unlabeled data ( 特別是在訓練前期 Confidence 不容易高的階段 ),且對於不同類別的過濾標準是完全一致的,並沒有考量不同資料在學習上的難易度差異。

為了解決上述問題,該論文提出「 Curriculum Pseudo Labeling ( CPL ) 」,主要是善用 Curriculum learning 的策略來考量不同類別的學習狀態。做法上是將原先 Pre-defined threshold 替換成可調整的 Thresholds,並根據不同類別在當下的學習狀態做動態調整。

值得留意的是,這個做法並不會增加額外的參數量、計算量,該論文的做法是直接將 Curriculum learning 應用於 FixMatch,並稱之為 FlexMatch

在實驗過程有發現 FlexMatch 具備更快的收斂速度,並達到 SOTA 的 Performance,CPL 的做法在 Label 非常稀少、任務較為困難的情況下,有很不錯的表現,詳細數據留在 Experiments 章節呈現。這邊總結一下該論文的貢獻:

  • 提出「 Curriculum Pseudo labeling ( CPL ) 」,動態的善用 Unlabeled data 來做 Semi-supervised learning,且易於整合其他 Semi-supervised learning 演算法。
  • CPL 在多種當代的 Semi-supervised learning 演算法上都能有大幅提升準確度、收斂效率的表現,而與 FlexMatch 進行整合可達到 SOTA 的水準。
  • 利用 PyTorch-based semi-supervised learning codebase 進行公平的研究比較,並將實作的原始碼開源為 TorchSSL,包含當前熱門的 Semi-supervised learning 演算法,具備易用性、可客製化等特色。

Background

過去常見的 Consistency regularization 作法會採用 L2 Loss 的形式:

B 是 Labeled data 的 Batch size;µ 是 Unlabeled data 與 Labeled data 的比例;ω 是 Data augmentation ( 上式的兩個 ω 是不同的 Data augmentation );u_b 是 Unlabeled data;p_m 表示模型的輸出的機率分佈。

與前面提到的 Pseudo labeling 技巧相結合後,就開始轉變為 Entropy minimization 的一種過程,對分類任務來說是較穩定的方法。改善後的 Consistency loss 可表述為:

H 是 Cross-entropy;τ 是個預先定義的 Threshold ( 用來過濾太像雜訊的 Unlabeled data );而 pˆ_m(y|ω(u_b)) 是 Pseudo label ( 可以是 Hard 或 Soft label )。

FixMatch 的作法也是類似形式,善用 Consistency regularization 加上 Strong augmentation 來做出較好的表現:

與前一道公式主要的差異在最右邊的 u_b 是經過 Strong augmentation ( Ω ) ,而不像左邊的 u_b 是經過 Weak augmentation ( ω ) 處理。

而前面提過可改善的 Pre-defined threshold 是上式的 τ ,作者們相信某些類別是比其他類別更難學習的,因此透過 Curriculum learning 有機會根據模型的學習狀態做更好的優化 。

FlexMatch

Curriculum Pseudo Labeling

為了讓模型在不同時間點可以對各類別有不同的 Threshold 來對 Unlabeled data 生成 Pseudo labels ,一個理想的作法是去計算各類別的 Evaluation accuracy 並對 Threshold 做 Scale :

上式的 a_t(c) 代表這模型在 t 時間點對類別 c 的準確度 a,這樣設計會有一個效果是:當某個類別準確度較低時,會連帶將過濾 Unlabeled data 的 Threshold 調低,並鼓勵模型在這類別上透過更多的 Sample 進行學習。

不過訓練過程是沒辦法使用到 Evaluation set 的,通常的作法會是從 Training set 切出另一個 Validation set,但這樣做有會有以下 2 個嚴重的問題:

  1. 在 Labeled data 非常稀少的情況下,再分割出一部分資料作為 Validation set 是非常昂貴的。
  2. 為了要在訓練過程動態調整那些 Thresholds ,會需要在訓練過程的每個時間點都評估模型當下的準確度,也就會拖慢訓練的速度 。

因此該論文在設計 Curriculum Pseudo Labeling ( CPL ) 的過程採用了另一個替代方案,讓訓練過程無須增加額外的推論,也不需要多切一個 Validation set。

而這替代方案的主要想法是:

當 Threshold 設定較高時,模型在各類別的學習成效會反映在各類別過濾後所留下的樣本數量。也就是說,若某些類別只有較少的樣本沒被過濾掉,那該類別的學習難度就是較高的,也同時表示模型在此類別的學習狀態較差。

示意圖如下:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Figure 1.

而該論文以下方公式來評估模型在各類別的學習成效:

σ_t(c) 表示的是模型在 t 時間點對 c 類別的學習成效;p_m,t ( y | u_n ) 是模型在 t 時間點對 Unlabeled data 的推論結果;N 則表示全部 Unlabeled data 的數量。

當 Unlabeled dataset 是平衡的時候, σ 越大表示模型在該類別學習得越好,再進一步透過以下公式將值域壓縮回 0 ~ 1 之間,便可用來對 Threshold 做調整:

當模型在某類別學習的好的時候,β_t(c) 會 = 1,也就讓 Threshold 的設定和預先定義的狀態相同;而當模型對某類別的學習成效較差時,就會因為 β_t(c) 介於 0 ~ 1 之間,而降低 Threshold 的設定,也就鼓勵模型在訓練過程對此類別使用更多的 Unlabeled data 進行學習。

就整個訓練過程來看,當模型在各類別的學習狀態提昇時,就會將各類別所對應的 Threshold 逐漸提高,直到最後模型在每個類別的 Threshold 都提高到一開始定義好的狀態。

值得留意的一點是,Threshold 不見得隨著訓練過程而提高,倘若 Unlabeled data 被分類至錯誤的累中,就有可能導致 Threshold 下降。

然而在該論文提出的 FlexMatch 中,新的 Threshold 作法有被用來計算 Unsupervised loss:

其中的 T_t 會隨著每次迭代而做改變。

而在 FlexMatch 的完整 Loss 中,會以加權的方式結合 Supervised loss:

而其中的 L_s 就是訓練在 Labeled data 上的 Supervised loss:

Threshold warm-up

在實驗過程中,作者們有觀察到這機制在訓練前期會盲目地對大多數的 Unlabeled data 進行預測,且結果會因為初始化的不同而偏向某一個特定類別,就又會陷入像是過去所謂 Confirmation bias 的窘境。

這也讓作者們覺得這階段所預測出的學習狀態是不值得信任的,所以在對 σ 做 Normalize 的時候加入了 Warm-up 的機制:

其中的這部份:

是指還沒被用到的 Unlabeled data,這個改動確保了訓練初期的 β_t(c) 會逐漸從 0 開始提昇,直到這些沒被用到的 Unlabeled data 數量不再主導分母的部份。( 這段時間的長短會取決於 Unlabeled data 數量和各類別的學習難度 )

Non-linear mapping function

在使 Threshold 更彈性的這個公式中:

Threshold 會因為經過 Normalize 的學習成效 β_t(c) 而產生一種線性的 Mapping,但在訓練初期,有可能因為模型預測的狀態還不穩定,而導致 β_t(c) 有大幅增減的情況,這在真實訓練模型的過程不見得是好的狀態。而在訓練中後期,當模型在某個類別已經學得不錯時,Threshold 的浮動就不會太大。

該論文提出一種非線性的 Mapping function 來讓 Thresholds 可以在 β_t(c) 介於 0 ~ 1 之間的時候,變成是非線性上升的曲線:

原版的公式可以看作這個 M() 的其中一種特例,也就是讓這個 M() 在最大值為 1/τ 的狀態下單調的提昇 β_t(c)。

為了避免額外增加參數量,作者們僅考量值域介於 0 ~ 1 的 Mapping function ,並直覺地選了一個 Convex function:

目的是讓 Thresholds 在 β_t(c) 數值小的時候可以成長的較緩慢,而在 β_t(c) 數值大的時候可以更敏感。後續也有實驗比較其他的 Mapping functions 。

而 FlexMatch 的整個演算法流程可表示為:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Algorithm 1.

Experiments

該論文在 CIFAR10/100 、SVHN 、STL-10 和 ImageNet 等常見的 SSL 資料集上驗證了 FlexMatch 和其他加上 CPL 演算法的成效,主要比較的對象包含 Pseudo-Labeling 、UDA 和 FixMatch ( 因為他們都有預先設定 Threshold 的作法 ),以及 Fully-supervised 的作法 。大多實驗設定都 Follow 著 FixMatch 的作法,細節請參照原文

Main results

在 CIFAR-10/100 、STL-10 和 SVHN 等資料集的分類任務 Error rate 如下表:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Table 1.

不難看出 FlexMatch 在多數的成果上都有 SOTA 的表現,其中在只有 40 or 1000 筆 Labeled data 的時候,Error rate 是最低的。

這邊條列一下 FlexMatch 所擁有的優勢:

  • CPL 的作法在 Labeled data 非常有限的情況下有較好的表現
  • CPL 可以改善現有的 SSL 演算法 ( 可以看到有 Flex- 前綴的實驗結果 )
  • CPL 在較複雜的任務有較佳的表現 ( STL-10 資料集的 Unlabeled data 有新種類的物件,會導致任務便得較困難 )

不過在這部份的實驗中, FlexMatch 對 SVHN 資料集的表現就沒那麽理想,作者們對此進行分析後,發現 SVHN 資料集在各類別的資料量是不平衡的,有些類別的資料少到永遠不會讓該論文提出的方法對這類別推算的學習狀態趨近於 1 ,也就會一直允許有 Noised pseudo-labeled samples 存在,並不斷在整個訓練過程中被學習。

然而反觀表現較佳的 FixMatch,由於 Threshold 固定在 0.95 來過濾掉那些 Noised pseudo-labeled samples,雖然也會過濾掉一些較難學習的樣本,但由於 SVHN 資料集本身是個相對簡單的任務,模型較容易學習也就較容易對 Unlabeled data 生成 High-confidence 的 Predictions,所以設定一個固定、數值較高的 Threshold 就是一個相對沒問題的作法。

Results on ImageNet

該論文也是著將 CPL 的作法放到 ImageNet-1K 的資料集上,試著在更真實、複雜的資料中驗證其效能。

作法是在 1000 個類別上隨機抽出 100 筆 Labeled data ( 不到整體的 8% ),並與 FixMatch 比較相同迭代次數 ( 2²⁰ 次 ) 時的 Top-1 、Top-5 Error rate:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Table 2.

結果也是 FlexMatch 比較好 ( 但這並不是 FlexMatch 在這資料集最好的表現,即便訓練了 2²⁰ 次的迭代後,模型仍然尚未收斂,只是礙於有限的運算資源只好作到這個程度 )

Convergence speed acceleration

另一個 FlexMatch 的優勢是它的收斂速度,下圖是比較 FlexMatch 和 FixMatch 在 CIFAR-100 資料集上只用 400 筆 Labeled data 的 Loss 和 Top-1 Acc.:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Figure 3 (a) & (b).

不論是看 Loss 還是 Acc. ,FlexMatch 都有較快收斂的趨勢,主要差異就在於 CPL 不像過去 Pre-defined threshold 作法,能在訓練初期將更多的 Unlabeled data 囊括進訓練流程,促使 Gradient 在計算時能更往 Global optimum 去。

另外該論文也針對 CPL 的特性,將 FlexMatch 和 FixMatch 在 CIFAR-10 資料集的訓練過程記錄下來,並對各類別的準確度進行比較:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Figure 3 (c) & (d).

在 200K 次的迭代後,FlexMatch 的準確度遠比 FixMatch 來得高,94.3% 的準確度甚至比 FixMatch 訓練 1M 次迭代之後還要高,這更進一步證明了 CPL 對較難學習的類別有它的效用,且能改善整體的學習成效。

Ablation study

這個章節主要針對以下 3 個 FlexMatch 的 Componenes 進行實驗:

  • Upper limit of thresholds τ
  • Mapping functions M(x)
  • Threshold warm-up

Threshold upper bound

針對 τ 的定義,實驗是做在 CIFAR-10 資料集並只有 40 筆 Labeled data 的情況:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Figure 4 (a).

很明顯的,最好的設定就是 0.95,不論增減都會造成模型表現變差 。

Mapping function

針對 Mapping functions M(x) 的實驗一樣是做在 CIFAR-10 資料集並只有 40 筆 Labeled data 的情況:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Figure 4 (b).

被拿來實驗的 Function 包含以下 3 種:

  • Concave:
  • Linear:
  • Convex:

這實驗證明了 Convex 的作法會比較好,雖然更進一步調整 Convexity 的程度會有可能帶來更好的改善,但該論文沒有繼續往下研究。

另一個可能訓練初期過濾 Unlabeled data 的作法是設計 Threshold 的下限,例如從原本的 0 ~ 1 改為 0.5 ~ 1,就能在一開始就過濾 Confidence 較高的 樣本。但由於這樣做會增加另一個 Hyper-parameter ,所以作者們就沒有將它設計為 FlexMatch 的一部分,但在實驗過程確實有發現此方法會帶來些微的效能提昇。

Threshold warm-up

針對 Threshold warm-up 機制的分析則是分別在 CIFAR-10 ( 40 labels ) 和CIFAR-100 ( 400 labels ) 資料集做實驗:

Zhang, Bowen, et al. “Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling.” Advances in Neural Information Processing Systems 34 (2021). Figure 4 (c).

分別帶來將近 0.2% 與 1% 的改善,證明了訓練初期在各類別善用大量的 Unlabeled data 做訓練會是一個比較好的方法 。

Comparison with class balancing objectives

CPL 的作法有一種在每個 Batch 平衡各類別 Unlabeled samples 的作用,類似個效果其實也能透過讓 Class distribution 在每個 Batch 趨近於 Uniform distribution 。

因此該論文還做了一種實驗是在 FixMatch 加入另一個額外的 Loss:

pˆ_c 是模型對 c 類別所有 Sample 預測出的機率分佈的平均;而 q 是 Uniform distribution,q_c = 1 / C。

這實驗在 CIFAR-10 ( 40 labels ) 設定的結果是 Error rate 7.16%,而 FixMatch 是 7.47%±0.28 ,FlexMatch 則是 4.97%±0.06。

但這種作法會需要資料本身在每個 Batch 就是平衡的狀態才會有這種效果,而 CPL 就沒有這種限制,所以該論文提出的方法是更彈性的,也減少了人為介入調整 Threshold 的麻煩。

Related Work

在 Semi-supervised learning 任務上結合 Curriculum learning 是近年常見的作法,不論是在影像分類語意分割多任務學習都有類似的嘗試。甚至在文本的情感分析任務也有動態調整 Threshold 的類似作法,或是在語意分割任務上有用額外的分類器來自動化調整 Threshold ,並用來處理 Domain inconsistency。較近期發表在 AAAI 2021 的這篇論文,就結合了 Self-traning 的方法,在訓練過程逐漸提昇 Threshold ,並達到了 SOTA 的成果 。

Conclusion and Future Work

該論文提出了「 Curriculum Pseudo Labeling ( CPL ) 」,以 Curriculum learning 的機制善用 Unlabeled data 來做 Semi-supervised learning,除了在效能上有顯著改善外,也提昇了收斂速度,且幾乎沒有增加額外的訓練負擔。

而該論文基於 FixMatch 做改善的結果稱為 FlexMatch,在多個 SSL 的 Benchmark 上都達到 SOTA 的水準,作者們期望未來會改善這方法在面臨 Long-tail scenario ( Unlabeled data 在各類別極度不平衡 ) 的表現 。

--

--

Ken Huang
人工智慧,倒底有多智慧?

在網路上自學的過程中,體會到開放式資源的美好,希望藉由撰寫文章記錄研究所的學習過程,同時作為回饋網路世界的一種方式。Email : kenhuang2019iii@gmail.com ,如果有任何問題都歡迎與我聯繫。