[資料分析&機器學習] 第5.4講: 機器學習進階實用技巧-正規化

Yeh James
JamesLearningNote
Published in
4 min readDec 24, 2017

--

Overfitting的狀況

在實務上我們在Train model時常常會遇到Overfitting的問題,也就是Model在Training的Data正確率很高,但是拿到Testing Data的時候錯誤率卻很高。背後的主因是我們真正需要的Model比我們Train出來的Model還要簡單,也就是Train出來的Model太複雜了!如上圖所示假設我們需要的Model是一條回歸的直線,但是Train出來的Model為了讓在Traing Data中錯誤率最小化,因此Model變得奇形怪狀,這樣的Model拿去新的資料中錯誤率就會很高

這時候通常有幾種解決方法:

  1. 收集更多的Training Data
  2. 減少資料的維度(特徵)
  3. 使用更簡單的Model
  4. 對現有的Model加上使用L1 or L2正規化(懲罰penalty)

前面有提到說我們Train出來的Model其實就是在空間中的一個多項式,像是: W0*X0 + W1*X1 +…..Wn*Xn 複雜一點的Model可能是: W0*X0 + W1*X1²+W2*X2³ + … + Wn*Xn¹⁰ ,如果我們要降低這些多項式Model的複雜程度,最常見的方式就是限制W的範圍,讓W越小越好,甚至變成0。

而限制W的方式主要有兩種

  1. L1正規化
  2. L2正規化

在Train Model時都是要找一組W讓整體的錯誤率最小,通常稱作最小cost,這個值就如圓心的點。

L1正規化就是在最小Cost的公式加上

L1正規化

整體最小就變成

L2正規化則是在最小Cost的公式加上

L2正規化

整體最小就變成

這兩者的差別就很數學了,L1比較容易造成稀疏解(也就是很多的Wi會等於0),以上面的例子來說在整體最小的情況下w2 =1, w1=0,這邊不深入討論兩者細微的差異了。原則上比較常用的正規化是L2,像是Scikit-Learn 內建的Penalty就是使用L2正規化。

Scikit-Learn Logistic Regression文件

感謝你閱讀完這篇文章,如果你覺得這些文章對你有幫助請在底下幫我拍個手(長按最多可以拍50下手)。

[Python資料分析&機器學習]這系列文章是我在Hahow上面所開設課程的講義,如果你是新手想著看影片一步一步學習,可以參考這門課:https://hahow.in/cr/pydataml

如果你對什麼主題的文章有興趣的話,歡迎透過這個連結告訴我:https://yehjames.typeform.com/to/XIIVQC
有任何問題也歡迎在底下留言或是來信告訴我: yehjames23@gmail.com

參考閱讀

  1. Python機器學習
  2. [莫烦]L1 / L2 正规化 (Regularization)

--

--