[資料分析&機器學習] 第5.4講: 機器學習進階實用技巧-正規化
在實務上我們在Train model時常常會遇到Overfitting的問題,也就是Model在Training的Data正確率很高,但是拿到Testing Data的時候錯誤率卻很高。背後的主因是我們真正需要的Model比我們Train出來的Model還要簡單,也就是Train出來的Model太複雜了!如上圖所示假設我們需要的Model是一條回歸的直線,但是Train出來的Model為了讓在Traing Data中錯誤率最小化,因此Model變得奇形怪狀,這樣的Model拿去新的資料中錯誤率就會很高
這時候通常有幾種解決方法:
- 收集更多的Training Data
- 減少資料的維度(特徵)
- 使用更簡單的Model
- 對現有的Model加上使用L1 or L2正規化(懲罰penalty)
前面有提到說我們Train出來的Model其實就是在空間中的一個多項式,像是: W0*X0 + W1*X1 +…..Wn*Xn 複雜一點的Model可能是: W0*X0 + W1*X1²+W2*X2³ + … + Wn*Xn¹⁰ ,如果我們要降低這些多項式Model的複雜程度,最常見的方式就是限制W的範圍,讓W越小越好,甚至變成0。
而限制W的方式主要有兩種
- L1正規化
- L2正規化
在Train Model時都是要找一組W讓整體的錯誤率最小,通常稱作最小cost,這個值就如圓心的點。
L1正規化就是在最小Cost的公式加上
整體最小就變成
L2正規化則是在最小Cost的公式加上
整體最小就變成
這兩者的差別就很數學了,L1比較容易造成稀疏解(也就是很多的Wi會等於0),以上面的例子來說在整體最小的情況下w2 =1, w1=0,這邊不深入討論兩者細微的差異了。原則上比較常用的正規化是L2,像是Scikit-Learn 內建的Penalty就是使用L2正規化。
感謝你閱讀完這篇文章,如果你覺得這些文章對你有幫助請在底下幫我拍個手(長按最多可以拍50下手)。
[Python資料分析&機器學習]這系列文章是我在Hahow上面所開設課程的講義,如果你是新手想著看影片一步一步學習,可以參考這門課:https://hahow.in/cr/pydataml
如果你對什麼主題的文章有興趣的話,歡迎透過這個連結告訴我:https://yehjames.typeform.com/to/XIIVQC
有任何問題也歡迎在底下留言或是來信告訴我: yehjames23@gmail.com