使用 TensorFlow 了解 Dropout

Airwaves
Airwaves
Sep 5, 2018 · 6 min read
Photo by Bronwyn on Unsplash

Dropout 的方法與原理

在 🔗過擬合與欠擬合 這篇文章中提到模型的容量資料的多寡皆會影響訓練的結果,如果模型容量過高或者是訓練資料過少很容易造成過擬合,在這篇文章中我們要談到一種對抗過擬合的方法:丟棄法 (Dropout)。

Dropout 是一種對抗過擬合的正則化方法,在訓練時每一次的迭代 (epoch)皆以一定的機率丟棄隱藏層神經元,而被丟棄的神經元不會傳遞訊息,例如 Fig 1 的每一層以 0.5 的機率丟棄神經元,所以再向前傳播時紅色被打叉的神經元不會傳遞訊息。

Fig 1. Dropout (取自 Tikz/Dropout)

此外,在反向傳播時,被丟棄的神經元其梯度是 0,所以在訓練時不會過度依賴某一些神經元,藉此達到對抗過擬合的效果。

但是,如果在訓練時以機率 p 丟棄神經元,而測試時不會丟棄神經元,因此,會造成測試的結果比訓練大 1/(1−p) 倍,所以為了保持輸出的期望值不變,會在測試時將神經元向前傳遞的訊息乘以 1−p

而我們在這個範例中所使用的 Dropout 是變種的 Inverted Dropout,與一般 Dropout 不同的是 Inverted Dropout 在訓練時會將結果除以 1−p,讓訓練時的期望值維持不變。

從零開始實現 Dropout

在這個範例中,我們會使用 MNIST 這個手寫數字的資料集,透過多層感知機實現辨識手寫數字,最後會使用 dropout 對應過擬合的問題,進而提升測試的準確率。

引入相依套件與資料

首先,我們要引入會用到的 3 個套件,numpy、tensorflow、matplotlib,再從 tensorflow 的官方範例資料集下載 MNIST。

定義 Dropout 函式

因為我們使用的是 inverted dropout,所以在丟棄神經元後,為了讓輸出的期望值保持不變,會將所有的神經元皆除以保留的機率 keep_prob。

定義類神經網路模型

我們定義一個 784 x 1000 x 10 的類神經網路,因為 MNIST 的影像長寬為 28 x 28,所以輸入節點為 784 個,而隱藏層的神經元數量為 1000 個,最後為 10 個節點,經過 softmax 後輸出分別為 0 ~ 9 的機率。

在使用 dropout 時,要特別注意的是只有在訓練時需要 dropout 神經元,而測試時不使用 dropout,因此我們可以用 tf.placholder with_default() 定義保留神經元的的機率,並在測試使用預設 1.0 的機率保留所有的神經元。

定義超參數

在這個範例中,除了設定批量大小(batch size)、迭代次數(epochs)、步幅(learning rate) 之外,我們還要再定義保留神經元的機率 keep_prob,但是為了實現程式碼,所以將 keep_prob 這個超參數的定義移至上一個部份。

定義損失函數、優化器與計算準確率的函數

在這個範例中,我們使用 softmax_cross_entropy_with_logits 這個函式,如果想要自己實做 cross entropy 與 softmax,需要特別注意 softmax 可能最造成數值爆炸的問題,因為一旦遇到像是 e⁵⁰ 這種數值,Python 無法容納如此龐大的數值,因此會發生 nan (not a number)。

訓練類神經網路模型

為了實驗 dropout 是一個有效的方法,在訓練時分別對沒有使用 dropout使用 dropout 的類神經網路都訓練 10 次,並取得各自的準確率作為比較的基準。

在這個範例中使用的是小批量的訓練方法,而 MNIST 的訓練集為 55000 筆資料,所以每一個次的迭代 (epoch) 總共訓練 55000/200=275 次。

繪製準確率比較圖

最後,我們用 matplotlib 視覺化排序後的準確率,可以發現使用 dropout 的類神經網路比沒有使用 dropout 的類神經網路有更好的準確率,所以從實驗中可以證實 dropout 是一個可以對抗過擬合的方法,提升模型的泛化能力。

結論

從實驗中可以發現 dropout 是一個有效降低過擬合的方法,與權重衰減的精神相似,皆是在訓練模型時,不過度依賴某些權重,進而達到正則化的效果。

但是要特別注意的是,dropout 只能在訓練時使用,所以會造成測試時向前傳播的訊息大於訓練時向前傳播的訊息,通常會在測試時乘以 1−p 這個問題。而在這個範例中,我們使用的是 inverted dropout,在訓練時將傳遞的訊息除以 1−p,讓實作更為簡潔。

參考資料

延伸閱讀

手寫筆記

學習永無止盡,我們一起學習。

Airwaves

Written by

Airwaves

人工智慧\電腦科學初學者,喜歡分享各式各樣的知識。

手寫筆記

學習永無止盡,我們一起學習。

Welcome to a place where words matter. On Medium, smart voices and original ideas take center stage - with no ads in sight. Watch
Follow all the topics you care about, and we’ll deliver the best stories for you to your homepage and inbox. Explore
Get unlimited access to the best stories on Medium — and support writers while you’re at it. Just $5/month. Upgrade