深入解析 TensorFlow 2.0 儲存與載入模型的各種方法

Ticking ShiHao
12 min readJan 8, 2020

--

存取模型參數是資料科學工作者的基本功

在這篇教程我們將會學到如何儲存模型,並讀取模型參數。TensorFlow 提供了 2 種存取模型的方法,一種是 TF 原生 SavedModel,另一種是透過 Keras API 存成 HDF5 檔案格式。 我們將使用 mnist 數字辨識資料集進行示範,訓練一個簡單的 Keras 模型,並以不同的檔案格式存取。

文章將包含以下內容:

  1. 用 Keras 儲存 HDF5 檔案格式的模型
  2. 以 TensorFlow SavedModel 格式儲存模型
  3. 讀取模型

讓我們開始吧 ~ !

載入模組

  • numpy: 專門處理矩陣、向量的套件。
  • matplotlib.pyplot: 繪圖套件。
  • tensorflow: 機器學習套件,可訓練模型、存取模型。
  • tensorflow_datasets: 用以下載公開資料集 mnist。
  • tensorflow.keras: 專注於神經網路的機器學習套件,使用上比 tensorflow 更簡單直觀。

載入資料集

這次我們使用 mnist 資料集,也就是手寫數字辨識的資料集,訓練資料共 60,000 筆,測試資料 10,000 筆。這邊我們使用了 tfds API 來下載官方開源資料。下方程式碼用到 as_supervised=True,代表回傳的資料格式為 (image, label),如果將 as_supervised 設定為 False,回傳資料為字典格式 {'image': tf.Tensor, ‘label’: tf.Tensor}。

mnist 資料集

資料前處理

原始資料的數值型態為 uint8,數值分布範圍是 0~255,必須先調整到 0~1 之間的浮點數,才能輸入神經網路加以訓練,因此在資料前處理的第一步,我們建立 format_image 函數,將數值轉成浮點數 (float32),並除以 255.0。

下面用到了許多 tf.data.Dataset 資料處理的方法,這裡一一解釋:

  • cache: 將資料全部載入快取記憶體,可加速模型訓練
  • shuffle: 洗牌以打亂資料順序
  • map: 使用某個函數來處理資料,這裡用到的是自己寫的 format_image 函數
  • batch: 將數個資料組成一個批次,這裡一批包含 32 個樣本: (image, label)
  • prefetch: 在 GPU 或 TPU 訓練資料的過程,預先用 CPU 準備好下一批資料,可加速訓練過程

建立模型

我們使用 tf.keras.Sequential 來建立模型,中間使用了 3 層卷積層 (Convolution neural network) 搭配池化層 (Max pooling layer),依序用了 16, 32, 64 個 filters。批次資料經過卷積之後,使用 Flatten 函數將資料降到 2 維。再來使用全連接層 (Dense) 串接 512 個神經元,再接上分類器輸出 10 類別的機率,數字 0 ~ 9 總共 10 個數字。除了分類器使用 softmax 函數將數值轉換成機率之外,其餘層都使用 relu 做為活化函數。

訓練模型

本次教學目的在儲存與讀取模型參數,因此只訓練 3 個回合。

Epoch 1/3 1875/1875 [==============================] — 45s 24ms/step — loss: 0.0846 — accuracy: 0.9742 — val_loss: 0.0828 — val_accuracy: 0.9743 
Epoch 2/3 1875/1875 [==============================] — 40s 22ms/step — loss: 0.0603 — accuracy: 0.9813 — val_loss: 0.0839 — val_accuracy: 0.9740
Epoch 3/3 1875/1875 [==============================] — 40s 22ms/step — loss: 0.0485 — accuracy: 0.9846 — val_loss: 0.0569 — val_accuracy: 0.9837

確認模型預測結果

模型訓練 3 個回合在測試集的精確度就達到 0.9837,看起來真的不錯,我們來確認一下訓練結果。

Labels: [3 4 9 6 0 9 4 1 4 7 9 3 1 8 3 1 4 1 3 6 1 1 8 4 5 4 8 3 2 4 8 4]
Predicted labels: [3 4 9 6 0 9 4 1 4 7 9 3 1 8 3 1 4 1 3 6 1 1 8 4 5 4 8 3 2 4 8 4]

100% 正確,可以說模型在這個 batch 的預測結果相當的好呢!

儲存模型 - Keras HDF5 格式

現在我們已經訓練完模型,要利用 Keras 將模型存成 HDF5 格式,附檔名會是 '.h5',是 Keras 模型默認的格式,程式碼非常簡單只有一行。

讀取 .h5 模型檔

接者,我們來試試看讀取 .h5 模型檔,並調用 .summary() 來看一下模型結構是否跟原本建好的模型一樣。

Model: "sequential" _________________________________________________________________ Layer (type)                 Output Shape              Param #    ================================================================= conv2d (Conv2D)              (None, 26, 26, 16)        160        _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 13, 13, 16)        0          _________________________________________________________________ conv2d_1 (Conv2D)            (None, 11, 11, 32)        4640       _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 5, 5, 32)          0          _________________________________________________________________ conv2d_2 (Conv2D)            (None, 3, 3, 64)          18496      _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 1, 1, 64)          0          _________________________________________________________________ flatten (Flatten)            (None, 64)                0          _________________________________________________________________ dense (Dense)                (None, 512)               33280      _________________________________________________________________ dense_1 (Dense)              (None, 10)                5130       ================================================================= Total params: 61,706 
Trainable params: 61,706
Non-trainable params: 0

另外也檢查一下模型的預測結果跟原本的預測結果有沒有差異。我們把原始模型的預測值與載入模型的預測值進行比較,將兩者相減取絕對值,如果兩者預測值完全相同,代表模型參數一模一樣,預測結果也完全相同,則相減後每一項都是0。

Outputs:

0.0

繼續訓練模型

在實務上,新的資料會不斷地進入資料庫,模型也會慢慢衰老,這將導致模型的預測效果會愈來愈差,這時我們必須載入舊的模型參數,以新的數據加以訓練,更新我們的模型,確保模型的表現符合預期。

輸出成 TensorFlow SavedModel

另一種儲存模型的方式是將模型參數存成 TensorFlow SavedModel 格式,裡面除了包含訓練好的參數之外,TF 的計算圖與程式邏輯都包含在這個檔案格式裡了。存成 SavedModel 格式的應用面會更廣,除了原本 python 程式可以使用之外,手機應用 TFLite、TensorFlow.js、TensorFlow Serving、TFHub 都可以使用,增加了程式佈署在不同應用的能力,可以說是非常的人性化。

以 SavedModel 格式儲存,會產出以下三種檔案及資料夾:

  1. assets 資料夾: TF 計算圖 (graph) 會用到的檔案都包含在 assets 裡面,比如自然語言處理會用到的單字表。
  2. variables 資料夾: TF checkpoint
  3. saved_model.pd 檔案: TF 程式、模型、TF 變數、以及 Tensor 等

這裡調用到 tf.saved_model.save 方法,第一個參數輸入要儲存的物件,第二個參數輸入存檔路徑。

讀取 TensorFlow SavedModel

讀取模型的法也很簡單,指定讀取路徑,一行程式碼就可以載入模型。

讀取模型檔後,來檢查預測值是否與原始模型相符。

Outputs:

0.0

用 Keras 讀取 TensorFlow SavedModel

SavedModel 並非 Keras 的物件,但是 Keras 有很多很好用的函數,比如 .fit(), .summary(), .predict() 等等,如果想要讀取 SavedMode 成為 Keras 的物件,也是做得到的。實務上,佈署應用程式我們會存成 SavedModel,當覺得模型表現愈來愈差時,會再載入成 Keras 物件,以 .fit() 函數更新模型,最後再存成 SavedModel 格式重新佈署應用程式。

Outputs:

Model: "sequential" _________________________________________________________________ Layer (type)                 Output Shape              Param #    ================================================================= conv2d (Conv2D)              (None, 26, 26, 16)        160        _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 13, 13, 16)        0          _________________________________________________________________ conv2d_1 (Conv2D)            (None, 11, 11, 32)        4640       _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 5, 5, 32)          0          _________________________________________________________________ conv2d_2 (Conv2D)            (None, 3, 3, 64)          18496      _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 1, 1, 64)          0          _________________________________________________________________ flatten (Flatten)            (None, 64)                0          _________________________________________________________________ dense (Dense)                (None, 512)               33280      _________________________________________________________________ dense_1 (Dense)              (None, 10)                5130       ================================================================= Total params: 61,706 
Trainable params: 61,706
Non-trainable params: 0 _________________________________________________________________

總結

TensorFlow 提供兩種模型存檔的方式,一種是存成 Keras HDF5 格式,副檔名是 .h5,另一種是存成 SavedModel,在這篇教程我們學會了這兩種方法保存與讀取模型。

如果你喜歡我的文章,請幫我按 Clap 給予我支持,也歡迎將文章分享到其他社群,我們下次見囉,掰掰~

--

--

Ticking ShiHao

一名在化工業苦蹲的資料科學工作者,興趣是當AI神棍。