Sharpness-Aware Minimization (SAM): 簡單有效地追求模型泛化能力

Jia-Yau Shiau
AI Blog TW
Published in
13 min readFeb 23, 2021

--

在訓練類神經網路模型時,訓練目標是在定義的 loss function 下達到一個極小值 (minima)。然而,在現今的運算資源下,類神經網路模型多半使用了過多的參數 (overparameterize),因此 training loss 低時也無法保證模型的泛化 (generalization) 能力。換句話說,不見得是表現最好的模型。

大部分的時候研究人員相信-收斂在較為平坦 (flat) minima 的模型,比尖銳 (sharp) 的 minima 具有更好的泛化能力。Sharpness-Aware Minimization (SAM) 是 Google 研究團隊發表於 2021年 ICLR 的 spotlight 論文,提出在最小化 loss value 時,同時最小化 loss sharpness 的簡單且有效方法。 SAM 在各個面向的實驗數據都相當亮眼,不只是展示了 generalization 能力,也在部分任務中刷新了 state-of-the-art 。

Cover photo by Canva

文章難度:★★☆☆☆
閱讀建議: 文章前段介紹了 flat 與 sharp 極值 (extrema) 的概念,閱讀時可以與 bias & variance tradeoff (under-fitting & over-fitting)做一些連接。中後段著重介紹 Sharpness-Aware Minimization (SAM) 的基本理論與實踐方法,並不會摻雜過多的數學。
推薦背景知識: optimization, non-convex optimization, gradient decent, underfitting and overfitting, bias and variance tradeoff, model regularization, ResNet.

Flat or Sharp Extrema

深度學習中模型與損失函數 (loss function) 形成的高維曲面通常都是非凸函數 (non-convex function) 。不像凸函數的 optimization,在 non-convex 的最佳化中達到全域 (global) 最佳解是困難的。因此,在尋找一個區域極值 (local extrema),或說損失低點 (loss minima) 時,會希望可以達到一個泛化能力較好的位置。

Convex & Non-convex

所謂的 convex 顧名思義指的是一個凸起的面,常聽到的用語包含凸函數 (convex function) 與凸集 (convex set) 兩種。以 convex set 解釋即在 convex set 中的任兩點中的連線必定完全落在 convex set 中,反之則為 non-convex set 。而 convex function 則是以 function 的形式描述這個特性。

--

--