簡介 Tensorflow Serialization — Part 2
之前的 Part 1。
我在這篇會簡單的介紹一下怎麼把自己訓練好的 model 儲存成單一一個 .pb 檔 (就像你在 Tensorflow Model Zoo 看到的模型那樣)。
本來在煩惱說用 MNIST 訓練、存模型寫到有點兒煩了,再拿它做範例有點兒沒挑戰性,結果看到這個 machrisaa/tensorflow-vgg,它把別人從 Caffe 轉到 python 的 code 再進一步用 numpy array (*.npy) 來存。如果你是 numpy 的愛好者,我是覺得可以用這個 repo 即可,但就想說都已經用 Tensorflow 了,應該再把它包得更漂亮點。所以啦,我這篇不做啥特別的事,就是想辦法把這裡面的 VGG Graph 拉出來存成 .pb 檔,這樣就能用我在之前 Part 1 介紹的方式把這 VGG 16/19 的 Graph 拿來用。
先來看看 code:
有些檔案有點兒長,你可以到這個 gist 複製我的 code。
嗯…本來以為有點兒挑戰,結果稍微看一下 repo 裡的 code 之後,大概只加了幾行 code 就寫完了。因為心情不夠爽快,所以只好再多寫一點兒。
Trainable Model To Model for Inference
其實我追 Tensorflow 裡有關 Serialize/Deserialize 的 code,大致上使用的情境跟步驟應該是像這樣:
- 我訓練好一個 model 了,訓練過程中為了 debug 存了不少 check point 檔案 (.chkp)。
- 把我確定要的 check point 檔載回 graph 中。
- 把 graph 中的
Variable變成constant(因為在使用模型時不需要訓練),稱之為 freeze graph 吧。 - 把 freeze graph 存成
.pb檔給人用。
其實步驟 4. 經過我這兩次簡介應該算是駕輕就熟了,但我還是寫了一個範例讓你從一個自己訓練好的模型,怎麼用 Tensorflow 提供的 API 完成步驟 1. 至 3.。
範例 code 如下:
步驟 1 : 儲存 Session
這部很單純,生成一個 tf.train.Saver 物件,在 Session 裡使用 saver 下的 save 方法指定存檔路徑。
存檔完後,在指定的路徑下應該會有以下檔案
1. XXX.data
2. XXX.index
3. XXX.meta步驟 2 : 回復 Saver 與 Graph
有了步驟 1 生成的檔案後,使用 tf.train.import_meta_graph 得到與訓練時相同狀態的 saver 與 graph 。
步驟 3 : 回復 Session
使用步驟 2 得到的 Saver 回復 Session (tf.train.Saver.restore 方法),之後把回復的 graph 轉換成 GraphDef (tf.Graph.as_graph_def 方法)。
步驟 4 : “Freeze” Graph
這裡我們會需要用到 tf.python.framework.graph_util模組裡提供的函數 convert_variables_to_constants 。需要注意的是,你必須提供你 graph 理想輸出的 Tensor 的所有名字。其實稍微瞄過原始碼後,就知道之所以要提供這個名單的原因是,Tensorflow 會以這些 node 為起點,找出整個 Graph 中與這些 node 相關的 SubGraph 是多少,然後把這個 SubGraph 裡的變數轉成常數輸出成一個 “freeze” GraphDef。
拿到 freeze 的 GraphDef 之後,事情就單純了。用 protobuf 把它 Serialize 起來存就完工啦!
細節就麻煩各位自己看 code 啦。
結果我還是直接用 MNIST 當範例了…Orz
沒辦法嘛,它就內建 Tensorflow 裡,用起來超方便的囉。原諒我吧 XD