使用 TensorFlow 了解 overfitting 與 underfitting

Leo Chiu
手寫筆記
Published in
8 min readAug 16, 2018

前言

假設你要想利用機器學習辨識手寫文字,在訓練時得到的訓練誤差很低,但是訓練完的模型辨識文字的效果卻不好,這是什麼原因呢?

在這篇文章,我們會用 TensorFlow 建立一個模擬函數的模型,透過修改模型的容量與超參數,特意讓模型發生過擬合與欠擬合,並觀察究竟在何種情況會發上以上兩種況。

訓練誤差與測試誤差

首先,我們先來重新溫習何謂訓練誤差測試誤差。在訓練模型時,我們會使用損失函數評估預測值與真實值的誤差,該誤差稱作訓練誤差 (training error);在訓練完模型後,任一訓練資料以外的資料用於訓練出的模型所產生的誤差稱作測試誤差 (testing error)

過擬合與欠擬合

假設我們想要訓練模型辨識手寫文字,因此建立一個多層感知機的模型,在訓練時會用訓練誤差與測試誤差評估模型的泛化能力。如果訓練誤差總是無法降低,預測的準確率很低,我們會稱這種情況為欠擬合 (underfitting);而如果訓練誤差很低時,在訓練資料的表現很好,但是卻在在測試集上無法獲得較好的結果,則會稱這種情況為過擬合 (overfitting)

接著,我們要來談談發生過擬合與欠擬合的兩個主要原因:模型容量訓練的資料多寡

何謂模型容量?

在機器學習的領域中,我們會說容量 (capacity) 是模型擬合數的能力。通常,越是複雜的模型其容量越高,但是容易造成過擬合。反之,越是簡單的模型其容量越低,卻容易導致欠擬合。

以下的函數是線性迴歸的模型,其能夠擬合的函數不外乎就是線性函數,如果資料成非線性的分布,則線性回歸的模型較難擬合資料的規律,因此該模型的容量較低。

如果是一個 9 次方的函數模型,相較於線性回歸模型,它較有能力擬合非線性的資料分佈,因此該模型的容量較大。

在選擇模型時,必須選擇適當容量的模型,避免發生過擬合或欠擬合。

訓練資料多寡

機器學習的基本條件之一便是資料量,如果想要訓練能夠辨識手寫文字的模型時,每一種文字只有寥寥無幾的資料數量,儘管經過長時間的訓練,一旦遇到風格非常不同的手寫文字,模型便會失準。

因此,如果資料過少,則訓練出的模型容易發生過擬合的結果。

模擬過擬合與欠擬合

接著,我們嘗試模擬過擬合與欠擬合發生的情況,將使用多項式回歸 (polynomial regression) 擬合以下函數,用擬合的結果解釋過擬合與欠擬合。

overfitting 與 underfitting

引入所需相依套件

首先引入在這個範例中會使用到的套件,我們需要使用 numpy 來產生資料,再者,利用 TensorFlow 建立多項式回歸的模型,學習資料隱藏的規律,最後,將會用 matplotlib 視覺化模型擬合的情況。

產生資料

我們使用 numpy 產生 200 筆在常態分佈中的數值作為 features,並將這些數值帶入目標函數生成 labels,並將 features 與 labels 各分一半作為訓練資料與測試資料。在生成 labels 時,我們加入一些噪音資料,讓資料的分佈更符合現實中的情況。

建立模型

欲擬合的函數共有兩個變數 y 與 x,所以我們要定義兩個 tf.placeholder,在之後訓練時傳入資料。

在定義運算時,首先定義多項式回歸模型的最高次方 e,並藉由 e 動態決定係數w 的個數;接著,將運算各別加入串列中,最後再經由 tf.add_n 一次性地加入所有的運算至 net

多項式回規模型與線性回歸模型有些類似,如果不熟悉如何建立模型可以參考使用 TensorFlow 學習線性回歸 (Linear Regression)

設定超參數

接下來,我們有三個超參數要設定,分別是 learning rate、batch size、epochs,以上三個超參數你都可以嘗試調整成不同的數值,觀察模型擬合函數的結果。

設定損失函數與優化器

在這個範例中,使用均方誤差 (Mean Squared Error, MSE) 作為損失函數,並使用梯度下降法優化參數。

建立迭代器

tf.data.Dataset 在 TensorFlow 1.4 的版本中,從 tf.contrib 移至核心 API,是一個可以用來迭代資料的函式庫。我們使用 tf.data.Dataset 作為迭代資料的工具。

訓練模型

在開始訓練之前,別忘記要初始化所有的變數以及用來迭代資料的迭代器。我們迭代的次數 (epochs) 為 100 次,可以嘗試調整看看這個超參數,觀察會發生什麼事情。

因為我們使用 with-as 開啟 Session,所以在 with-as 裡面的程式碼結束後,Session 會自動關閉。如果你是初學者,可能在訓練時會發現,一旦再重新開啟一個 Session,此時原本訓練好的參數都不見了,所以我們在訓練完模型時,必須要保存模型的參數

保存模型的參數可以使用 tf.train.Saver(),所有在計算圖中的資訊都會被儲存在 checkpoint (.ckpt 檔) 中,當我們再次啟動 Session 時,可以載入 checkpoint 回復之前的模型參數與狀態。

視覺化模型擬合函數

首先載入模型參數與狀態,接著用 matplotlib 視覺化模型擬合函數的結果。我們分別視覺化 2 次多項式模型、3 次多項式模型與 4 次多項式模型和真實資料擬合的狀況。

欠擬合 (underfitting)

我們視覺化 2 次多項式模型擬合函數的結果,發現的容量過低的模型不易擬合函數;再從右圖可以看到訓練誤差與測試誤差都沒有收斂,所以我們可以判定模型發生了欠擬合。當模型欠擬合時,解決方法是增加模型的容量,就可能可以解決欠擬合。

欠擬合(Underfitting):二次多項式模型

正常擬合

接下來我們看到三次多項式的模型可以幾乎很完美地擬合函式,從右圖看到訓練誤差與測試誤差非常地接近,而且隨著迭代次數增加,模型慢慢地收斂,選擇適當的模型可以順利地擬合函數。

正常:三次多項式模型

過擬合 (overfitting)

容量大於實際上需要解決的問題時,容易發生過擬合的問題。以下的例子是四次多項式的模型,雖然模型隨著迭代增加訓練誤差越來越小,但是訓練誤差與測試誤差仍相差一段距離,可以判斷該模型發生了過擬合的問題。

過擬合(overfitting):四次多項式模型 - 容量過大

另外一種會發生過擬合的情況是訓練資料不足,我們刻意調整訓練資料的數量至只有 10 筆,當我們利用三次多項式擬合函數時,模型會發生過擬合。

過擬合(overfitting):三次多項式模型 - 訓練資料不足

結論

從這篇文章的例子中,你們會發現模型的容量資料量的多寡皆會影響訓練的結果,所以在訓練模型前必須收集足夠的資料與選擇適當的模型,藉此避免過擬合與欠擬合。

除了調整模型的容量與資料量以外,未來我們會再談談對抗過擬合的正則化 (regularization)方法。

--

--

手寫筆記
手寫筆記

Published in 手寫筆記

學習永無止盡,我們一起學習。

Leo Chiu
Leo Chiu

Written by Leo Chiu

每天進步一點點,在終點遇見更好的自己。 Instragram 小帳:@leo.web.dev