使用 TensorFlow 了解 Dropout

Leo Chiu
手寫筆記
Published in
6 min readSep 5, 2018
Photo by Bronwyn on Unsplash

Dropout 的方法與原理

在 🔗過擬合與欠擬合 這篇文章中提到模型的容量資料的多寡皆會影響訓練的結果,如果模型容量過高或者是訓練資料過少很容易造成過擬合,在這篇文章中我們要談到一種對抗過擬合的方法:丟棄法 (Dropout)。

Dropout 是一種對抗過擬合的正則化方法,在訓練時每一次的迭代 (epoch)皆以一定的機率丟棄隱藏層神經元,而被丟棄的神經元不會傳遞訊息,例如 Fig 1 的每一層以 0.5 的機率丟棄神經元,所以再向前傳播時紅色被打叉的神經元不會傳遞訊息。

Fig 1. Dropout (取自 Tikz/Dropout)

此外,在反向傳播時,被丟棄的神經元其梯度是 0,所以在訓練時不會過度依賴某一些神經元,藉此達到對抗過擬合的效果。

但是,如果在訓練時以機率 p 丟棄神經元,而測試時不會丟棄神經元,因此,會造成測試的結果比訓練大 1/(1−p) 倍,所以為了保持輸出的期望值不變,會在測試時將神經元向前傳遞的訊息乘以 1−p

而我們在這個範例中所使用的 Dropout 是變種的 Inverted Dropout,與一般 Dropout 不同的是 Inverted Dropout 在訓練時會將結果除以 1−p,讓訓練時的期望值維持不變。

從零開始實現 Dropout

在這個範例中,我們會使用 MNIST 這個手寫數字的資料集,透過多層感知機實現辨識手寫數字,最後會使用 dropout 對應過擬合的問題,進而提升測試的準確率。

引入相依套件與資料

首先,我們要引入會用到的 3 個套件,numpy、tensorflow、matplotlib,再從 tensorflow 的官方範例資料集下載 MNIST。

定義 Dropout 函式

因為我們使用的是 inverted dropout,所以在丟棄神經元後,為了讓輸出的期望值保持不變,會將所有的神經元皆除以保留的機率 keep_prob。

定義類神經網路模型

我們定義一個 784 x 1000 x 10 的類神經網路,因為 MNIST 的影像長寬為 28 x 28,所以輸入節點為 784 個,而隱藏層的神經元數量為 1000 個,最後為 10 個節點,經過 softmax 後輸出分別為 0 ~ 9 的機率。

在使用 dropout 時,要特別注意的是只有在訓練時需要 dropout 神經元,而測試時不使用 dropout,因此我們可以用 tf.placholder with_default() 定義保留神經元的的機率,並在測試使用預設 1.0 的機率保留所有的神經元。

定義超參數

在這個範例中,除了設定批量大小(batch size)、迭代次數(epochs)、步幅(learning rate) 之外,我們還要再定義保留神經元的機率 keep_prob,但是為了實現程式碼,所以將 keep_prob 這個超參數的定義移至上一個部份。

定義損失函數、優化器與計算準確率的函數

在這個範例中,我們使用 softmax_cross_entropy_with_logits 這個函式,如果想要自己實做 cross entropy 與 softmax,需要特別注意 softmax 可能最造成數值爆炸的問題,因為一旦遇到像是 e⁵⁰ 這種數值,Python 無法容納如此龐大的數值,因此會發生 nan (not a number)。

訓練類神經網路模型

為了實驗 dropout 是一個有效的方法,在訓練時分別對沒有使用 dropout使用 dropout 的類神經網路都訓練 10 次,並取得各自的準確率作為比較的基準。

在這個範例中使用的是小批量的訓練方法,而 MNIST 的訓練集為 55000 筆資料,所以每一個次的迭代 (epoch) 總共訓練 55000/200=275 次。

繪製準確率比較圖

最後,我們用 matplotlib 視覺化排序後的準確率,可以發現使用 dropout 的類神經網路比沒有使用 dropout 的類神經網路有更好的準確率,所以從實驗中可以證實 dropout 是一個可以對抗過擬合的方法,提升模型的泛化能力。

結論

從實驗中可以發現 dropout 是一個有效降低過擬合的方法,與權重衰減的精神相似,皆是在訓練模型時,不過度依賴某些權重,進而達到正則化的效果。

但是要特別注意的是,dropout 只能在訓練時使用,所以會造成測試時向前傳播的訊息大於訓練時向前傳播的訊息,通常會在測試時乘以 1−p 這個問題。而在這個範例中,我們使用的是 inverted dropout,在訓練時將傳遞的訊息除以 1−p,讓實作更為簡潔。

參考資料

延伸閱讀

--

--

Leo Chiu
手寫筆記

每天進步一點點,在終點遇見更好的自己。 Instragram 小帳:@leo.web.dev