簡介 Tensorflow Serialization — Part 1

昨晚追了一下 Tensorflow 的 code ,趁記憶猶新想說來寫個簡介留個紀錄,以後自己忘了可以看 XD。那廢話說到這裡,咱們來看看 Tensorflow 是怎麼做 Serialization 的。我這邊會假設你已經知道 ProtocolBuffer 是什麼也已經知道一些基本的使用方式,如果你不知道的話,可以參考官方文件或者是我之前寫的一篇簡介

所謂的 Serialization ,說穿了就是,假想你用 Tensorflow 寫好也訓練好了一個神經網路之後,怎麼把這個訓練好的模型存起來,方便下次直接使用或在其他平台上使用 (例如嵌入式系統上) 。

剛開始用 Tensorflow 的時候,因為也不知道 protobuf 在幹啥,所以想要把模型的參數取出來時,都直接使用 tf.trainable_variables() 拿到所有參數的 Tensor ,然後在 Session 裡把它 evaluate 之後另外存,大致上像這樣:

import tensorflow as tf
def build_graph():
... buid graph here
def restore_graph(weights):
... restore graph with given weights
def save_weights(weights, out_fname):
... saving weights to a file
def read_weights(fname):
... read weights from a file
graph = build_graph()
with Session(graph=graph) as sess:
.... training training training
variables = tf.trainable_variables()
weights = sess.run(variables)
save_weights(weights, "my_weights.model")
restored_weights = read_weights("my_weights.model")
new_graph = restore_graph(restored_weights)
... do stuff

這樣做是很直 (ㄘㄨ) 接 (ㄌㄨˇ) 的做法。優點當然是…很好懂 XD,當初跟 Tensorflow 還沒那麼熟的時候這樣做確實是可以動,以一個懶惰工程師來說,不用看多少文件 code 就能動當然很開心啦!但缺點是,身為懶惰工程師,發現這樣做之後一整個要寫的 code 就變多了。除了要有 build graph 的 code 要寫,還有 read/write weights 的函數,更別提如果牽扯到部署的話,你自己寫的這些 code 全部都要部署上去才會動。結論:惡夢。

所以,為了增 (ㄐㄧˋ) 進 (ㄒㄩˋ) 效 (ㄊㄡ) 率 (ㄌㄢˇ),最好的方法就是透過 Tensorflow 提供的 API 進行模型的存取會是最好的,而 Tensorflow 是透過 protobuf 做高效能的模型存取。


進入正題啦!

Pre-Trained Model

之前就聽說 Tensorflow 有提供的 pre-trained model,今天算是有機會試試了。Tensorflow 的 repo 逛來逛去,覺得 object detection 這東西我比較熟悉,所以就決定拿它當例子了。它本身也有一個很棒的 ipython notebook 介紹這模型的用法,我主要也是從這裡繼續看下去的。

本來有想講講這些 .pb 檔是怎麼做出來的,順便提提 Tensorflow 的原始碼裡相關的 .proto 檔,但後來發現這需要說的有點兒多,所以這邊就先專注在怎麼把這些 .pb 檔 load 進來。之後再補一篇怎麼自己生出 .pb 檔。希望自己不會只是說說而已 XD。

Tesorflow Detection Model Zoo 看來 ssd_inception_v2_coco 看起來是最快的,所以就決定試玩它了。咱們直接看 code:

你可以從這裡看到所有的原始碼。

要載入別人提供的 .pb 檔到你的 Graph 裡,大致步驟如下:

  1. 使用 tf.Graph.as_default 將想要載入的 Graph 設為 default graph 。
  2. 使用 tf.GraphDeftf.Graph.as_graph_def 取得 GraphDef 物件。
  3. 讀取 .pb 檔內容 (binary mode),並指用 GraphDef.ParseFromString 解析讀取的內容。
  4. 使用 tf.import_graph_def 將步驟 3 裡的 GraphDef 物件載入 default graph 中。

從這裡不難看到,Tensorflow 主要是透過 tf.GraphDef 這個物件去載入模型。這時你跑去看了電腦上 tf.GraphDef 的原始碼,那原始碼簡直不像是人寫出來的….

你或許會問,為什麼要這樣多此一舉?code 還那麼醜!

其實如果你對 ProtocolBuffer 這東西熟悉的話,看到第 3 步的 GraphDefParseFromString 應該就已經猜到了,這原始碼確實不是人寫的,而是透過 ProtocolBuffer Compiler 生成的 (我有寫個簡單的簡介)。所以說,雖然有習慣看原始碼是件好事,但這裡就先算了吧。

一旦載入成功,後面的事就還蠻單純的,大多就是看看文件或 paper 找說實作裡的 input/output tensor 叫啥名字之類的,然後把你要跑的資料塞進去跑就對了。

除了看一堆文件或 code 之外,也可以使用 tensorboard 視覺化載入的模型。這樣應該比看文件跟 code 直觀,有空的話,我再寫怎麼用吧 XDD。

希望之後閒一點可以把怎麼自己生出 .pb 檔給別人用和 tensorboard 的簡介寫一寫,希望啦。

如果有說錯的地方,就請不吝賜教啦。m(_ _)m


Part 2 在這裡。

One clap, two clap, three clap, forty?

By clapping more or less, you can signal to us which stories really stand out.