【機器學習 03】機器要如何學得更好?Part 1
在本篇文章中,我們會提到如何幫助機器學得更好的一些技巧,分別為·:
- Nonlinear Transformation
- Overfitting
- Validation
- Regularization
還會介紹三個 learning principles。
在本篇文章中,我們先介紹 Nonlinear Transformation、Overfitting 和 Validation。
Nonlinear Transformation
線性假設是有限制的,並非所有資料都是線性可分,也許能夠用一個圓形分割。
因此可以試著將資料 X 投射到 Z 空間中,經過線性轉換,在 X 空間中 circular separable 的資料,到 Z 空間就可以變成線性可分。
例如圓形的表示式是 h(x) = sign(0.6 - x₁² - x₂²),若用 z₁ 代替 x₁²、z₂ 代替 x₂²,經過 nonlinear feature transform 後,得到 h(z) = sign(0.6 —z₁ — z₂),則在 z 空間即是線性可分。
然而,這個解法只能轉換某些特定的圖形,我們需要更概括的解法,也就是直接使用 nonlinear model,而非 nonlinear transformation + linear models。
在使用 nonlinear models 有幾個需要注意的事情:
- 計算和儲存的花費
- 模型複雜度的代價
- 模型泛化性(generalization)的問題
- 視覺化的危險性:在二維平面上很容易視覺化,所以能夠挑選比較好的模型,但是這是視覺化後的主觀判斷,要避免用這種方式選擇
記得永遠都要先用 linear model 嘗試,因為比較簡單、有效率、安全且比較可行。
Hazard of Overfitting
不好的 generalization 即為 overfitting,in sample 的 error 很小,但是 out sample 的 error 很大。若將模型的複雜度降低,則有可能解決此問題。
而 underfitting 則是 in sample 的 error 很大,out sample 的 error 也變大,此時傾向於將模型的複雜度增加。
產生 overfit 的原因:
- data size 太小
- 隨機性錯誤(stochastic noise)太大
- 確定性錯誤(deterministic noise)太大
- 當資料不夠多時,excessive power 大
確定性錯誤(deterministic noise):無法因為模型改變而消失的錯誤,被給定的樣本資料決定
處理 overfitting 的方法:
- 先從簡單的模型開始
- 資料清理、修剪,例如若不同類別的資料距離太近,則可以清理資料(修正標籤)或是修剪資料(移除此資料點)
- 擴增資料庫,資料偏移、旋轉
- validation
- regularization(會在 Part 2 介紹)
Validation
模型有太多要學習的,即使是最簡單的二元分類問題,也要決定:
- 要選什麼模型:linear regression 或 logistic regression
- learning rate
- feature transformation
- regularization:要使用什麼 regularizer
⋯⋯等等,許多要調整的選項
要如何選擇模型呢?
out sample 的 error:是未知,也不能視覺化地選擇。
in sample 的 error:很危險,因為有可能會是不好的 generalization。
test sample error:不可行且欺騙的。
這些樣本組都沒辦法使用,因此我們提出一個 validation sample,可行並且合法地欺騙。如果 validation sample 從未在資料集中出現,則能夠很好地驗證模型效果。
在實作時,我們將資料集 𝒟 分成 𝒟ₜᵣₐᵢₙ 和 𝒟ᵥₐₗ,測試不同的模型,依照 validation set 的 error 選出表現最好的模型後,再用所有資料重新測一次,通常會得到更好的結果。
至於要如何切出 validation set 呢?林軒田老師的建議是使用 k = N / 5。
Leave-One-Out Cross Validation (LOOCV)
LOOCV 也是一個常見的 cross validation 切分方式,每一次只取一個資料點當作 validation set(k = 1),經過數學證明推導可知 LOOCV 的 error 能夠跟 out sample 的 error 很接近。
但是這方法也有幾點壞處:
- 需要 N 次額外的模型訓練,在實際應用時,可能不是每次都有這些資源
- 可能因為單一資料點過於偏誤,而影響模型表現
因此在真實訓練時,常用的是 k-fold cross validation,通常都設定 k = 10。
想要更深入了解的話,別忘了去最上面的目錄看其他章的課程筆記!
喜歡這篇文章或是對你有幫助的話,別忘了拍手給我鼓勵哦 👏🏻
參考資料
- 林軒田,機器學習基石與技法:https://www.youtube.com/c/hsuantien/playlists