投影梯度下降法解正則化問題:以Lasso回歸為例

Edward Tung
數學、人工智慧與蟒蛇
11 min readMay 11, 2019

Solving regularized problem via proximal gradient algorithm:Take Lasso regression as an example

【前言:約束問題(Constrained Problem)】

本文會以對於梯度下降法有一定認識為前提來撰寫,如果對於該方法沒有太多了解,建議可以參考我上一篇文章。

進入正題,在最佳化問題中,我們往往會對目標函數加入一些限制條件。舉個生活化的例子來說,我們希望自動化安排每周的工作內容,當然,目標就是最大化這些工作所帶來的收益,然而,我們可能希望每周總工作時間不超過一定定量,或是某些工作應該優先被完成等等,這些就是所謂的約束條件,即在最佳化問題中,最佳解不可以違背這些條件。

在機器學習以及統計的領域中,約束條件的目的多有不同。最常見的是作為懲罰項使用,回憶一下我們提過的線性迴歸最小平方解,我們可以透過加入一項限制條件,來給予目標函數一些 "懲罰",意思是當目標函數走到一個極端時,我們不希望他繼續朝這個方向走下去,因為可能會出現如Overfitting等問題,舉例來說,我們可以加入一個L1懲罰項,也就是L1-norm:

這邊分兩步驟講解,首先,上述的改寫叫做Lagrangian Multiplier Method,對於 f(x) 為目標函數,g(x)為限制函數的情況下,一般形式會寫成:

Source : Wikipedia

這個方法在大一微積分應該有詳細介紹過推導過程,如果對此方法不熟悉,可以參考我下面附上的文章,我們這邊只說明一個直觀理解的方式,簡單來說,透過加入 Multiplier,在實務上多半為手動加入的懲罰係數,我們可以給予原目標函數一個懲罰,比如上面線性迴歸的例子,如果 L1-Norm > 0 (注意絕對值相加不會 <0),我們的目標函數會變大,也就不符合我們對它最小化的預期。而 Multiplier 的作用則在於,當我們在迭代收斂的過程中,扮演減速的緩衝,比如說在梯度下降的過程,每一次的更新不見得都在限制區域內執行,因此我們透過該乘數,來限制梯度更新的方向。

那麼下一個問題是,加入該懲罰項到底有甚麼意義?為甚麼不是別的懲罰方式,比方說很直觀地限制所有係數不可大於1? 當然,不同的懲罰項會有不同的作用,其中範數懲罰項是比較泛化的約束條件,比如L1, L2-Norm,或是比較少見的L0-norm,對於L1, L2兩種懲罰項的解釋,我補充一篇文章在下面,這邊只簡單說明一下L1的優缺點。

L1 懲罰項的目的在於,可以生成一個稀疏權重矩陣 (Sparse Weight Matrix),作用在於在迭代的過程中非常一翻兩瞪眼,只有 >0 或是 <= 0 兩種選擇,也就是多數的係數是沒有用的,這也能夠幫助我們篩選係數,比如說在輸入的變數非常多時,為了避免被過多變數干擾,我們可以選擇L1正規化去篩選係數,得到比較精簡的預測結果。但相反地,對於一些貢獻似有若無的係數,我們可能很武斷地將其拋棄,整體而言預測效果也許就不如其他種類的懲罰項來的優異。

最後還有一個問題是,當我們這樣改寫以後,我們的線性迴歸就不再是線性,因此很難透過線性運算找到閉式解(Closed-form Representation),我們就要透過一種特殊方法來找到它的最佳解。

因此,本篇文章將會用 Lasso 回歸,也就是加入 L1懲罰項的回歸問題來展示新的方法,投影梯度法 (Proximal Gradient Algorithm)。

【KKT條件(Karush-Kuhn-Tucker Condition)】

對於非線性函數,只要目標函數與限制函數皆為可導函數,我們可以透過KKT條件去檢驗之。假定 f(x) 為目標函數,g(x)為限制函數,則:

也就是說,我們可以透過這四個條件去得到最佳解,滿足KKT條件等於要找到最佳解的充要條件,在上面的式子中,多出了一個 d 變數,這個變數稱為 Dual Variable,也就是我們的 Lagragian Multiplier,詳細的解說可以參考線代啟示錄的文章:

【投影梯度法(Proximal Gradient Algorithm)原理】

有了以上的先備知識以後,我們接著進入到投影梯度算法。這個算法的基礎概念是,雖然原先的梯度下降法遇到非線性問題會有問題,因為我們不能保證每一次梯度下降的更新值,都能夠滿足我們的限制條件。那麼,我們就透過將梯度更新後的值 "投影" 到受限集合上,而投影出來的新位置,就是我們要更新的方向,首先,我們先看一張圖:

Source : Carnegie Mellon University ; Convex Optimization / Ryan Tibshirani

上圖中,紅色方塊代表限制區域,當我們在原函數上執行梯度下降以後,如黑色虛線所示,可能會偏離限制區域,此時,透過投影,我們可以將黑色虛線的向量方向變更為其在限制域的邊界上移動,也就是紅色實線的方向,這樣一來就得到了更新方向,雖然並非在原函數上改進最多,但是是在限制條件下相對改進得最好的方向,我們稱這樣的做法為投影梯度算法(PGA)。

投影方向應該如何計算呢? 我們先介紹一個新名詞,投影算子(Proximal Operator),定義如下:

注意到上式中如果投影目標 g(x) 是 Convex Set 的指示函數(定義在集合 X 上的函數,有哪些屬於某一子集A),則相當於在 Convex Set 上,最接近 x 的點,此外,我們可以透過一些方法將 x 改寫成兩個投影向量的組合。

上式中,g*(x) 稱為 g(x) 的共軛函數,意義上即為對於給定的 x,使其線性函數與原函數差異的最大值,這邊僅需知道定義即可,我們不打算在此著墨太多,該函數的主要用途,在於幫助我們推導投影算子,因此對於詳細的證明,可以參考以下這篇文章:

接著,我們來看我們的目標,也就是 L1-正則化的函數,我們想要知道的是,如何得到對 L1-正則化 的投影算子?

推導步驟略顯複雜了點,但整體而言並不難理解,我們可以發現對於L1-正則化的投影算子,當 x 的絕對值超過 Regulation Multiplier 時,我們將其往反方向移動,比如 x > λ,我們就將 x 減去 λ 的值,如果沒有超過,則直接設為 0,可以發現這是一個非常嚴格的算法,取決與 Multiplier 的大小,會讓 x 不斷往某方向收斂,直到每次更新不再下降為止。

整體而言,投影梯度方法的好處是,只要函數可以寫成 l(x) + g(x) 的形式,就能夠很簡單地被迭代求解,但仍然要求 l(x) 是differentiable,而 g(x) 如上面證明,可以用次梯度(Sub-differential)的方式去求得即可。

【投影梯度算法實現】

接下來,我們將進入到實作的環節。首先介紹基礎概念,在經過一般的梯度下降以後,將其投影到限制函數(為凸函數)上,就得到投影梯度算法:

我們可以照著上面的方式去進行迭代,這個算法因為其計算與概念簡單,用來解限制條件下的最佳化問題是非常強力的工具。回到我們線性規劃的例子,因為我們已經知道梯度下降的迭代方法,因此我們可以改寫成:

PGA Iteration Method

一樣是看上去複雜實際上並不困難的展開,第一條的展開方式我們已經在上一篇梯度下降的時候提到,因此我們將第二條改寫成 Proximal Operator 的標準形式,並加入梯度的 Closed-From Representation,就可以得到以上的推導方式。此外,細心的讀者可以發現,透過投影梯度下降我們避免了直接與約束條件的梯度開戰,避免了處理在 x = 0 的不可導問題。

接著,我們進一步來看步長選擇,一樣我們透過Lipschitz Constant就可以得到一個合理的 Upper bound,且因為避免了直接對約束條件求導,因此步長選擇與一般梯度下降的步長選擇一致!

此外,因為是近端梯度法,不適用梯度小於某值就停止的條件(可以用上一篇的方法自行推導看看),我們採取另一種方式。

也就是說,當我們更新的變化值小於一定程度的時候,就可以停止迭代了。

【Acceleration:快速投影梯度法(Fast Proximal Gradient Algorithm)】

在 Nesterov 1988年的論文中,提到了一種加速投影梯度算法的方式,在原先的投影梯度算法上,多引進了一個步驟:

這種方法也稱作 FISTA (Fast Iterative Shrinkage Thresholding Algorithm),能夠在某些情況有效加速運行效率,核心思想是每次執行完迭代後,下一次的迭代會由一個推斷出來的 z 點開始:

Source : UCLA / L. Vandenberghe

這樣做法的好處是,理論上可以有 O(1/K²) 的收斂速率,在大多數時候都比一般的投影梯度法的 O(1/K) 要快得多,。比方說在 l1-norm linear regression 問題中,可以更有效收斂:

Source : UCLA / L. Vandenberghe
  • 開始實際寫Code前的題外話,Lasso迴歸配合FISTA似乎可以用在圖片去模糊化,大體是因為可以篩選圖片特徵,不過我還沒嘗試過,之後看能否比較一下 CNN 跟此方法的一些差異。

【實作:解 L1-Regularized Linear Regression Problem】

這邊要計算的量比較大,我將一些函數建立成另外一個 class,把訓練、預測等放在主要的 class。這邊為了展示比較多東西,採用的是 Lipschitz 固定步長配合 FISTA 算法來估計 L1-線性回歸問題 (Lasso問題)。先上代碼:

我們再把誤差與迭代次數畫出來看看:

可以看到我們的程式成功執行,並得到以下的結果:
Iterations : 586
Loss : 2.3608920021921946
MSE : 0.00016522783472280948
Coef : [0.023, 0.048, 2.034, 0.056, -1.931, 0.036, 2.026, -1.951, 0.035, 2.069]

單以MSE來評估的話,效果甚至會比上一篇的單純迴歸方程好,這是因為 L1 模型通常是被用來產生稀疏矩陣,也就是會觀察到某些係數為 0,實務上的解釋就是某些特徵對該模型重要性不大。

--

--

Edward Tung
數學、人工智慧與蟒蛇

Columbia Student || 2 yrs of data scientist and 1 yr of business consultant experience