Tensorflow model file format

Ran
Ran ( AI Deep Learning )
9 min readDec 9, 2019

這篇文章,介紹 Tensorflow 模型文件

  • TensorFlow 模型格式簡介
  • Checkpoint (*.ckpt) 介紹
  • GraphDef (*.pb) 介紹
  • SavedModel 介紹
  • Tensorflow 模型文件格式轉換

TensorFlow 模型格式簡介:TensorFlow 的模型格式有很多種,針對不同場景可以使用不同的格式,只要符合規範的模型都可以輕易部署到線上服務或移動裝置上,這裡簡單列舉。

* Checkpoint: 用於儲存模型的權重,主要用於模型訓練過程中引數的備份和模型訓練熱啟動。* GraphDef:用於儲存模型的Graph,不包含模型權重,加上checkpoint後就有模型上線的全部資訊。* ExportModel:使用 exportor 介面匯出的模型檔案,包含模型 Graph 和權重可直接用於上線,但官方已經標記為 deprecated 推薦使用 SavedModel。* SavedModel:使用 saved_model 介面匯出的模型檔案,包含模型 Graph 和許可權可直接用於上線,TensorFlow 和 Keras 模型推薦使用這種模型格式。* FrozenGraph:使用 freeze_graph.py 對 checkpoint 和 GraphDef 進行整合和優化,可以直接部署到 Android、iOS 等移動裝置上。* TFLite:基於 flatbuf 對模型進行優化,可以直接部署到 Android、iOS 等移動裝置上,使用介面和 FrozenGraph 有些差異。
  • 模型格式
目前建議 TensorFlow 和 Keras 模型都匯出成 SavedModel 格式,這樣就可以直接使用通用的 TensorFlow Serving 服務,模型匯出即可上線不需要改任何程式碼。不同的模型匯出時只要指定輸入和輸出的 signature 即可,其中字串的 key 可以任意命名只會在客戶端請求時用到,可以參考下面的程式碼示例。注意,目前使用 tf.py_func() 的模型匯出後不能直接上線,模型的所有結構建議都用 op 實現。
  • TensorFlow 模型匯出
  • Keras 模型匯出
  • SavedModel 模型結構
使用 TensorFlow 的 API 匯出 SavedModel 模型後,可以檢查模型的目錄結構如下,然後就可以直接使用開源工具來載入服務了。

Checkpoint (*.ckpt) 介紹:Tensorflow 模型包含兩個內容。1. 神經網絡的結構圖 graph。已訓練好的變量參數

  • Tensorflow 模型包含兩個文件:
1. meta graph:保存 tensorflow 網絡圖,副檔名為 .meta2. checkpoint file:檔案為二進制文件,包含權重變量,biases 變量和其他變量,文件副檔名為 .ckptcheckpoint 
model.ckpt-240000.data-00000-of-00001
model.ckpt-240000.index
model.ckpt-240000.meta
checkpoint:列出保存的所有模型以及最近模型的相關信息
.data:包含訓練變量的文件
.index:描述 variable 中 key 和 value 的對應關係
  • 導出 ckpt 文件:
訓練完成後,為保存所有的變量和網絡結構,需使用:tf.train.Saver()需注意,tensorflow 變量的作用範圍是在一個 session 裡面,所以在保存模型的時候,應該在 session 裡面通過 save 方法保存。import tensorflow as tf 
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')
# --------------------------------------------------------------
如果希望在迭代 1000 次之後保存模型,可以把當前的迭代步數傳進
saver.save(sess, 'my_test_model',global_step=1000)訓練時,假設每 1000 次就保存一次模型,但是這些保存的文件中變化的僅是神經網絡的 variable,而網絡結構沒有變化,不需要重復保存 .meta 文件所以可以設置只讓網絡結構保存一次:saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)# --------------------------------------------------------------
如只想保留最新的 4 個模型,並希望每 2 個小時保存一次,可以使用 max_to_keep 和 keep_checkpoint_every_n_hours:
#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
# --------------------------------------------------------------
如果沒有在 tf.train.Saver() 指定任何參數,這樣表示默認保存所有變量。
如果不希望保存所有變量,而只是其中的一部分
此時可以指點要保存的變量或者集合:
只需在創建 tf.train.Saver 時把一個列表或者要保存變量的字典作為參數傳進去。
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000)
  • 導入 ckpt 文件:
1. 從 .meta 文件導入原始網絡結構圖:saver = tf.train.import_meta_graph('my_test_model-1000.meta')# --------------------------------------------------------------
加載了網絡結構圖之後還需要加載變量數據
2. 加載變量:使用 restore() 方法恢復模型的變量參數with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my_test_model-1000.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./'))
# --------------------------------------------------------------
在此之後, w1 和 w2 的 tensor 已經恢復:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('my-model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
print(sess.run('w1:0'))
# Model has been restored. Above statement will print the saved value of w1.
  • 從 ckpt 文件恢復訓練模式:
恢復任何預先訓練的模型,並用它進行 inference,fine-tuning 或者進一步訓練。
在 tensorflow 中,如果有佔位符,那麼就需要將數據傳入佔位符中,
但是當保存 tensorflow 模型的時候,佔位符的數據不會被保存的(佔位符本身的變量是被保存的)
# --------------------------------------------------------------
import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

# Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print sess.run(w4,feed_dict)
# Prints 24 which is sum of (w1+w2)*b1

# Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)
所以需要恢復它時,不僅要恢復網絡結構和相關變量參數,還需要準備新的 feed_dic(數據) 傳入佔位符中。通過 graph,get_tensor_by_name() 方法可以恢復所保存的佔位符和 opertor。比如下面的 W1 是一個佔位符,op_to_restore 是一個算子。# How to access saved variable/Tensor/placeholders
w1 = graph.get_tensor_by_name("w1:0")

# How to access saved operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
  • 從 ckpt 文件恢復訓練模式,並修改模型結構:

從 Tensorflow-ckpt 模型檔案,解析並顯示網路結構圖

如何從 CKPT 模型檔案中提取網路結構圖並實現視覺化。
理論上,既然能從 pb 模型檔案中提取網路結構圖,CKPT 模型檔案自然也不是問題,但是其中會有一些問題。
# --------------------------------------------------------------
1. 解析 CKPT 網路結構
解析 CKPT 網路結構的第一步是讀取 CKPT 模型中的圖檔案,得到圖的 Graph 物件後即可得到完整的網路結構。讀取圖檔案示例程式碼如下所示。
呼叫 graph.get_operations() 後即可得到當前圖的所有計算節點,在利用 Operation 物件與 Tensor 物件之間的相互引用關係即可推斷網路結構。但是需要注意的是,從 meta 檔案中匯入的圖中獲取計算節點存在如下問題。* 包含反向梯度下降計算的所有節點
* 某些計算節點是按基礎計算(加減乘除等)節點拆分成多個計算節點的,
如BatchNorm,但其實是可以直接合併成一個節點的。
pb 模型檔案可以避免上面第一個問題,將 CKPT 模型轉 pb 模型後,可以自動將反向梯度下降相關計算節點移除。對於第二點,pb 模型檔案會自動將基礎計算組成一個計算節點,但是對於 Tensor 操作的函式如 Slice 等函式是無法合併的。因此,對於第 2 個問題,將 CKPT 模型轉 pb 模型後,可以減少這類問題,但是無法避免。徹底避免的方法只能通過自己針對性地實現。經過以上分析,得出的結論是非常有必要將CKPT模型轉pb模型。# --------------------------------------------------------------
2. 自動將 CKPT 轉 pb,並提取網路圖中節點
# --------------------------------------------------------------
3. 測試
以 MobileNet V1 網路結構為例,
下載 MobileNet_v1_1.0_192 檔案並壓縮後,得到
mobilenet_v1_1.0_192.ckpt.data-00000-of-00001、mobilenet_v1_1.0_192.ckpt.index、
mobilenet_v1_1.0_192.ckpt.meta檔案。
我們還需要知道 mobilenet_v1_1.0_192.ckpt 模型對應的輸入和輸出 Tensor 物件的名稱,官方提供的壓縮包檔案中並沒有告知。一種方法是執行官方程式碼,把輸入 Tensor 的名稱打印出來。
但是執行官方程式碼本身就需要一定的時間和精力,
底下有介紹「從 Tensorflow-pb 模型檔案,解析並顯示網絡結構圖」
其程式碼實現中已經實現了將原始網路結構對應的字串寫入到 ori_network.txt 檔案中。
因此,可以先隨意填寫輸入名稱和輸出名稱,待生成ori_network.txt檔案後,從檔案中可以直觀看到原始網路結構。ori_network.txt檔案部分內容如下所示。
通過該檔案可知,
輸入 Tensor 的名稱為:batch:0
輸出Tensor名稱為:MobilenetV1/Predictions/Reshape_1:0
有了這些資訊後,呼叫函式read_graph_from_ckpt得到靜態圖的節點列表物件ops,呼叫函式 gen_graph(ops,"save/path/graph.html") 後,在目錄 save/path 中得到 graph.html 檔案,開啟 graph.html 後,顯示結果如下。

GraphDef (*.pb) 介紹:此文件包含 protobuf 對象序列化後的數據,包含計算圖,可以從中得到所有運算符(operators)的細節,也包含 tensors

  • 如何有效的檢視已有的 pb 模型檔案
# --------------------------------------------------------------
# 重新載入模型檔案,並輸出定義
model = 'model.pb'
with tf.Session() as sess:
with open(model, 'rb') as model_file:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
print(graph_def)
# 採用上述的方式可以在新的會話中重新載入本地的模型檔案(pb),然後二進位制解析後,輸出可以看到結果。但是如果網路層結構十分複雜,那麼這種顯示方式就會比較難以閱讀。# --------------------------------------------------------------
# 重新載入模型檔案,並使用 Tensorboard 進行視覺化處理
from tensorflow.python.platform import gfile
model = 'model.pb'
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(gfile.FastGFile(model, 'rb').read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)
# 然後會在你的 log 資料夾下面生成檔案。
# 在終端中執行
tensorboard --logdir DIR --host IP --port PORT
# 一般情況下,不設定 host 和 port,就會在 localhost:6006 啟動。DIR是路徑(不加引號)# --------------------------------------------------------------
# 上面的例子:
tensorboard --logdir log然後在瀏覽器中訪問 localhost:6006 就可以視覺化你的網路結構
  • 從 Tensorflow-pb 模型檔案,解析並顯示網絡結構圖
# --------------------------------------------------------------
# Tensor 對象 與 Operation 對象
Tensorflow 中,
Tensor 對象主要用於存儲數據如常量和變量(訓練參數),
Operation對象是計算節點,如卷積計算、反卷積計算、ReLU等等。
每一個 Operation 對象均有輸入和輸出 Tensor,
同理,每個 Tensor 對象均有對應生成該 Tensor 的 Operation 對象和使用該 Tensor 對象作為輸入的 Operation 對象。Tensor 和 Operation 對象內均有相關屬性和函數來獲取其關聯的 Operation 和 Tensor 對象,相關屬性如下所示。1. Tensor 對象的 op 屬性指向生成該 Tensor 的 Operation 對象。
2. Tensor 對象的 consumers() 函數獲取使用該 Tensor 對象作為輸入的 Operation 對象。
3. Operation 對象的 inputs 屬性指向該計算節點的輸入 Tensor 對象。
4. Operation 對象的 outputs 屬性執行該計算節點的輸出 Tensor 對象。
如下圖所示的網絡結構中,
調用 Tensor_ 2對象的 consumers() 函數,返回的是 [op_1,op_2]。
Tensor_3 的 op 屬性指向的是 op_1。
op_1 的 inputs 屬性指向的是 [Tensor_1,Tensor_2],
op_1 的 output 屬性指向的是 [Tensor_3]。
有了 Tensor 與 Operation 對應在圖中的關聯關係,就可以將網絡結構給畫出來。
# --------------------------------------------------------------
# 測試模型顯示
文中介紹的 MobileNet V1 網絡結構為例,下載 MobileNet_v1_1.0_192文件並壓縮後,得到 mobilenet_v1_1.0_192_frozen.pb 文件。還需知道 mobilenet_v1_1.0_192_frozen.pb 模型對應的輸入和輸出 Tensor 對象的名稱,好在 MobileNet_v1_1.0_192 壓縮包中包含文件 mobilenet_v1_1.0_192_info.txt。通過該文件可知,
輸入 Tensor 的名稱為:input:0,
輸出 Tensor 名稱為:MobilenetV1/Predictions/Reshape_1:0。
有了這些信息後,調用函數 read_graph_from_pb 得到靜態圖的節點列表對象 ops,調用函數 gen_graph(ops,"save/path/graph.html") 後,在目錄 save/path 中得到 graph.html 文件,打開 graph.html 後,顯示結果如下。
MobileNet V1 網絡結構解析並展示效果
MobileNet V1 網絡結構解析並展示效果
  • CNNGraph 程式碼

MobileNet V1 官方預訓練模型

MobileNet V1 定義網絡結構文件,請見下列https://raw.githubusercontent.com/tensorflow/models/master/research/slim/nets/mobilenet_v1.py打開以上網址,可以看到 MobileNet V1 官方預訓練的模型,官方提供了不同輸入尺寸和不同網絡中通道數的多個模型,並且提供了每個模型對應的精度。可以根據實際的需要下載對應的模型這裡以選擇 MobileNet_v1_1.0_192 為例,表示網絡中的所有卷積後的通道數為標準通道數(即1.0倍),輸入圖像尺寸為 192X192。
從下列網址,下載《MobileNet_v1_1.0_192》http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192.tgz# --------------------------------------------------------------
如上下載 mobilenet_v1.py 文件後,使用其中的 mobilenet_v1 函數構建網絡結構靜態圖,如下代碼所示。
代碼中,使用函數 tf.nn.top_k 取概率最大的 3 個類別機器對應概率
# 加載模型參數
# 先定義 placeholder 輸入 inputs,再通過函數 build_model 完成靜態圖的定義。接下來傳入 tf.Session 對象到 load_model 函數中完成模型加載。
CKPT = 'mobilenet_v1_1.0_192.ckpt'
def load_model(sess):
loader = tf.train.Saver()
loader.restore(sess,CKPT)

inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3))
classes_tf,scores_tf = build_model(inputs)
with tf.Session() as sess:
load_model(sess)
# --------------------------------------------------------------
模型測試-加載Label
網絡輸出結果為類別的索引值,需要將索引值轉為對應的類別字符串。
先從官網下載label數據,需要注意的是 MobileNet V1 使用的是 ILSVRC-2012-CLS(http://www.image-net.org/challenges/LSVRC/2012/) 數據,因此需要下載對應的 Label 信息。解析 Label 數據代碼如下。def load_label():
label=['其他']
with open('label.txt','r',encoding='utf-8') as r:
lines = r.readlines()
for l in lines:
l = l.strip()
arr = l.split(',')
label.append(arr[1])
return label
# --------------------------------------------------------------
模型測試-測試結果
執行inference.py後,控制台輸出結果如下所示。
識別 test_images/test1.png 結果如下:
No. 0 類別: 軍用飛機 概率: 0.9363691
No. 1 類別: 飛機翅膀 概率: 0.032617383
No. 2 類別: 炮彈 概率: 0.01853972
識別 test_images/test2.png 結果如下:
No. 0 類別: 小兒床 概率: 0.9455737
No. 1 類別: 搖籃 概率: 0.044925883
No. 2 類別: 板架 概率: 0.007288801
# --------------------------------------------------------------
完整代碼,inference.py 完整的代碼如下

ckpt 模型轉換 pb 文件

雖然打包下載的文件中包含已經轉換過的 pb 文件,但是官方提供的 pb 模型輸出是 1001 類別對應的概率,需要的是概率最大的 3 類。
可在原始網絡中使用函數 tf.nn.top_k 獲取概率最大的 3 類,將函數 tf.nn.top_k 作為網絡中的一個計算節點。
模型轉換代碼如下所示。上面代碼中,單一的所有類別概率經過計算節點 tf.nn.top_k 後分為兩個輸出:
概率最大的 3 個類別 classes,
概率最大的 3 個類別的概率 scores。
執行上面代碼後,在目錄「model」中得到文件 mobilenet_v1_1.0_192.pb。

SavedModel 介紹:在使用 TensorFlow Serving 時,會用到此格式的模型。該格式為 GraphDef 和 CheckPoint 的結合體,另外還有標記模型輸入和輸出參數的 SignatureDef。從 SavedModel 中可以提取 GraphDef 和 CheckPoint 對象。

└────── 1
···├──── saved_model.pb
···└──── variables
······├──── variables.data-00000-of-00001
······└──── variables.index
pb 文件 + variable 目錄 (.index 文件 + .data)
如果從 pb 文件中轉出來的模型,variable 文件夾中為空,因為 p b文件裡面的各項參數都是 tf.constant,所以不會存儲到 variable 裡面。
  • 模型導出:
  • Tensorflow serving 服務的部署

Tensorflow 模型文件格式轉換:Tensorflow 模型的 graph 結構可以保存為 .pb 文件或者 .pbtxt 文件,或者 .meta 文件,其中只有 .pbtxt 文件是可讀的。

訓練好的網絡,往往會將模型保存為一個統一的 .pb 文件,
這個文件中不止保存著模型網絡的結構和變量名,還保存了所有變量的值,
如果想利用訓練好的模型對自己的數據進行測試,往往要對這個模型做一些修改。
這時需要知道原有模型裡面的一些張量名稱,但是 .pb 文件和 .meta 文件都是不可讀的,所以有必要對這兩種文件進行格式轉換。
  • .meta 文件
這種情況下,通常還需要其他幾個 checkpoint 文件
checkpoint
model.cpkt.index
model.cpkt.data 等
可以使用 tensofrflow 安裝目錄下的 /tensorflow/tensorflow/python/tools/inspect_checkpoint.py 文件打印輸出模型中所有張量(tensor)和操作(op)的名稱下面是 inspect_checkpoint.py 的全部代碼:
  • .pb 文件
下面的代碼定義了兩個函數,可以實現 .pb 文件和 .pbtxt 文件之間的轉換

TensorFlow 簡單 pb (pbtxt) 文件讀寫

在使用 TensorFlow 的時候,很多地方都會遇到 protobuf (.proto) 文件,比如配置 TensorFlow Detection API 的過程中需要執行如下語句:$ protoc object_detection/protos/*.proto --python_out=.表示把文件夾 protos 下的所有 .proto 文件轉化對應的 .py 文件。 之後,再借助這些轉化來的 .py 文件就可以讀取特定格式的 .pbtxt, .config 等文件了。比如,可以使用 string_int_label_map.proto
轉化來的 string_int_label_map_pb2.py 文件來讀取目標檢測與實例分割 的類名與類標號配置文件(這種文件以 .pbtxt 作為後綴名,假設要檢測 person 和 car 等類目標):
# --------------------------------------------------------------
item {
id: 1
name: 'person'
}

item {
id: 2
name: 'car'
}

...
# --------------------------------------------------------------
Protocol Buffer(protobuf) 是谷歌開源的一種數據存儲語言,每一個文件以 .proto 為後綴,它不依賴特定的語言與平台,擴展性極強。TensorFlow 內部的數據存儲(比如 .pb.ckpt 等)基本都使用 protobuf 格式。

下面以一個簡單的例子來說明 protobuf 文件的讀取。
  • protobuf 數據結構定義
假設我們的目的是讀取如下文件 (命名為 students.pbtxt,使用文本編輯器編輯):# --------------------------------------------------------------
student_info {
name: 'Zhang San';
age: 20;
sex: 0;
}

student_info {
name: 'Li Si';
age: 25;
sex: 0;
}

student_info {
name: 'Wang Wu';
age: 18;
sex: 1;
}
# --------------------------------------------------------------
顯然,這是一份簡單的結構化數據,但若使用傳統的數據讀取方式且要快速方便的解析出其中的學生信息,卻不容易。此時,如果使用 protobuf 則相當便捷。
首先,用文本編輯器編輯一個 .proto 文件 (命名為:student_info.proto)# --------------------------------------------------------------
syntax = "proto3";

package proto_test;

message Student {
string name = 1;
int32 age = 2;
int32 sex = 3;
}

message StudentInfo {
repeated Student student_info = 1;
}
# --------------------------------------------------------------
其中的 syntax 指定使用的 protobuf 版本,可以填寫 proto2 (protobuf 2) 和 proto3 (protobuf 3),這裡使用的是後者。
下面的 package 指定 student_info.proto 文件所在的文件夾名字。接下來,以關鍵字 message 開頭定義了一個簡單的數據結構 Student,裡面包括三個可選字段 name,age 和 sex,後面的數字 1,2,3 指定這三個字段在編碼序列化後的二進制數據中的順序,因此在同一個 message 內部,它們是不允許重復的。如果是 protobuf 2 的版本,需要在可選字段前面加上 optional 關鍵字 (關鍵字包括 optionalrequiredrepeated)定義好 Student 結構之後, students.pbtxt 文件的內容基本是重復這個結構,因此還需要定義一個新的 message StudentInfo,它包含一個可重復的 (repeated) 字段 student_info。至此,數據解析格式定義完了,接下來要將它轉化為 Python 格式語言,執行:$ protoc proto_test/*.proto --python_out=.# --------------------------------------------------------------
會自動生成一個 student_info_pb2.py 文件,前幾行如下:
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: proto_test/student_info.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()DESCRIPTOR = _descriptor.FileDescriptor(
name='proto_test/student_info.proto',
package='proto_test',
syntax='proto3',
...
# --------------------------------------------------------------
基於這個 student_info_pb2.py 就可以方便的解析 students.pbtxt 中的內容了。
  • 讀取 .pbtxt 文件
有了轉化來的 student_info_pb2.py 文件,解析 students.pbtxt 就輕而易舉了,代碼如下 (命名為:read_student_info.py)
首先, 使用 tf.gfile.GFile 將指定文件讀入成字符串,然後定義一個 student_info_pb2.StudentInfo() 結構,這樣使用 google.protobuftext_format.Merge 直接將 student_info 一個一個解析成 Student 結構,此時讀取 name、age 和 sex 字段只需要通過 . 屬性即可。# --------------------------------------------------------------
如開頭所言,我們來讀取 students.pbtxt 文件:
student_info {
name: 'Zhang San';
age: 20;
sex: 0;
}
student_info {
name: 'Li Si';
age: 25;
sex: 0;
}
student_info {
name: 'Wang Wu';
age: 18;
sex: 1;
}
# --------------------------------------------------------------
讀取代碼如下:
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 8 19:26:54 2018
@author: shirhe-lyh
"""
import read_student_infoif __name__ == '__main__':
student_info_path = './students.pbtxt'
students_dict = read_student_info.get_student_info_dict(student_info_path)
print(students_dict)
# --------------------------------------------------------------
執行後輸出:
{'Li Si': [25, 0], 'Wang Wu': [18, 1], 'Zhang San': [20, 0]}

TensorFlow 訓練自己的目標檢測器

本文主要描述如何使用 Google 開源的目標檢測 API 來訓練目標檢測器,內容包括:
安裝 TensorFlow/Object Detection API 和
使用 TensorFlow/Object Detection API 訓練自己的目標檢測器。
  • 安裝 TensorFlow Object Detection API
Google 開源的目標檢測項目 object_detection 位於與 tensorflow 獨立的項目 models(獨立指的是:在安裝 tensorflow 的時候並沒有安裝 models 部分)內:models/research/object_detection。models 部分的 GitHub 主頁為:https://github.com/tensorflow/models要使用 models 部分內的目標檢測功能 object_detection,需要用戶手動安裝 object_detection。下面為詳細的安裝步驟:
  • 安裝依賴項 matplotlib,pillow,lxml 等
使用 pip/pip3 直接安裝:$ sudo pip/pip3 install matplotlib pillow lxml其中如果安裝 lxml 不成功,可使用$ sudo apt-get install python-lxml python3-lxml
  • 安裝編譯工具
$ sudo apt install protobuf-compiler
$ sudo apt-get install python-tk
$ sudo apt-get install python3-tk
  • 克隆 TensorFlow models 項目
使用 git 克隆 models 部分到本地,在終端輸入指令:$ git clone https://github.com/tensorflow/models.git克隆完成後,會在終端當前目錄出現 models 的文件夾。
要使用 git(分布式版本控制系統),首先安裝 git:
$ sudo apt-get install git。
  • 使用 protoc 編譯
在 models/research 目錄下的終端執行:$ protoc object_detection/protos/*.proto --python_out=.將 object_detection/protos/ 文件下的以 .proto 為後綴的文件編譯為 .py 文件輸出。
  • 配置環境變量
在 .bashrc 文件中加入環境變量。首先打開 .bashrc 文件:$ sudo gedit ~/.bashrc然後在文件末尾加入新行:export PYTHONPATH=$PYTHONPATH:/.../models/research:/.../models/research/slim其中省略號所在的兩個目錄需要填寫為 models/research 文件夾、models/research/slim 文件夾的完整目錄。保存之後執行如下指令:$ source ~/.bashrc讓改動立即生效。
  • 測試是否安裝成功
在 models/research 文件下執行:$ python/python3 object_detection/builders/model_builder_test.py如果返回 OK,表示安裝成功。
  • 訓練 TensorFlow 目標檢測器
成功安裝好 TensorFlow Object Detection API 之後,就可以按照 models/research/object_detection 文件夾下的演示文件 object_detection_tutorial.ipynb 來查看 Google 自帶的目標檢測的檢測效果。其中,Google 自己訓練好後的目標檢測器都放在:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md可以自己下載這些模型,一一查看檢測效果。以下,假設你把某些預訓練模型下載好了,放在models/ research/ object_detection 的某個文件夾下,比如自定義文件夾 pretrained_models。要訓練自己的模型,除了使用 Google 自帶的預訓練模型之外,最關鍵的是需要準備自己的訓練數據。以下,詳細列出訓練過程:
  • 準備標注工具和文件格式轉化工具
圖像標注可以使用標注工具 labelImg,直接使用$ sudo pip install labelImg安裝。另外,在此之前,需要安裝它的依賴項 pyqt4:$ sudo apt-get install pyqt4-dev-tools(另一依賴項 lxml 前面已安裝)。要使用 labelImg,只需要在終端輸入 labelImg 即可。為了方便後續數據格式轉化,
還需要準備兩個文件格式轉化工具:xml_to_csv.pygenerate_tfrecord.py
它們的代碼分別列舉如下(它們可以從資料 [1] 中 GitHub 項目源代碼鏈接中下載。其中為了方便一般化使用,已經修改 generate_tfrecord.py 的部分內容使得可以自定義圖像路徑和輸入 .csv 文件、輸出 .record 文件路徑,以及 6 中的 xxx_label_map.pbtxt 文件路徑):
(1) xml_to_csv.py 文件源碼:
(2) 修改後的 generate_tfrecord.py 文件源碼:generate_tfrecord.py 也可以由 models/research/object_detection/dataset_tools 文件夾內的相關 .py 文件修改而來。
  • 創建工作目錄,收集圖片
在 Ubuntu 中新建項目文件夾,比如 xxx_detection (xxx 自取,下同),在該文件夾內新建文件夾 annotations,data,images,training。將所有收集到的圖片放在 images 文件夾內。
  • 標注圖片生成 xml 文件
利用標注工具 labelImg 對所有收集的圖片進行標注,即將要檢測的目標用矩形框框出,填入對應的目標類別名稱,生成對應的 xml 文件,放在 annotations 文件夾內
  • 將所有的 .xml 文件整合成 .csv 文件
執行 xml_to_csv.py(放在 xxx_detection文件夾下),將所有的 xml 標注文件匯合成一個 csv 文件,再從該 csv 文件中分出用於訓練和驗證的文件 train.csv 和 val.csv(分割比例自取),放入 data 文件夾。
  • 將 .csv 文件轉化成 TensorFlow 要求的 .TFrecord 文件
將 generate_tfrecord.py 文件放在 TensorFlow models/research/object_detection 文件夾下,在該文件夾目錄下的終端執行:$ python3 generate_tfrecord.py --csv_input=/home/.../data/train.csv  
--images_input=/home/.../images
--output_path=/home/.../data/train.record
--label_map_path=/home/.../training/xxx_label_map.pbtxt
類似的,對 val.csv 執行相同操作,生成 val.record 文件。
(其中 xxx_label_map.pbtxt 文件見下面)
  • 編寫 .pbtxt 文件
仿照 TensorFlow models/research/object_detection/data 文件夾下的 .pbtxt 文件編寫自己的 .pbtxt 文件:對每個要檢測的類別寫入item {
id: k
name: ‘xxx’
}
其中 item 之間空一行,類標號從 1 開始,即 k >= 1。將 .pbtxt 文件命名為 xxx_label_map.pbtxt 並放入training 文件夾
  • 配置 .config 文件
從 TensorFlow models/research/object_detection/samples/configs 文件夾內選擇合適的一個 .config 文件複製到項目工程的 training 文件夾內,將名稱改為與工程相關的 保留模型名 _xxx.config (其中保留模型名為原 .config 文件關於模型的命名字段,建議命名時保留下來,xxx 為與項目相關的自己命名字段),打開文件作如下修改:(1)修改模型參數將 model {} 中的 num_classes 修改為工程要檢測的類別個數。另外,也可以修改訓練參數:train_config: {} => num_steps: xxx => schedule {} => step = xxxnum_steps 表示將要訓練的次數,刪除這一行為不確定次數訓練 (隨時可用 Ctrl+C 中斷),後面的 step 表示學習率每過 step 步後進行衰減。這些參數由自己的經驗確定,也可以使用默認值。其它參數一般不需要修改。(2)修改文件路徑將 .config 文件中所有的 ’PATH_TO_BE_CONFIGURED’ 文件路徑修改為相應的 .ckpt (預訓練模型文件路徑),.record,.pbtxt 文件所在路徑。將修改後的 保留模型名_xxx.config 文件放在 training 文件夾內。
  • 開始本地訓練目標檢測器
在 TensorFlow models/research/object_detection 目錄下的終端執行:$ python3 model_main.py --model_dir=/home/.../training 
--pipeline_config_path=/home/.../training/保留模型名_xxx.config
進行模型訓練,期間每隔一定時間會輸出若干文件到 training 文件夾。在訓練過程中可使用 Ctrl+C 任意時刻中斷訓練,之後再執行上述代碼會從斷點之處繼續訓練,而不是從頭開始 (除非把訓練輸出文件全部刪除)
  • 查看實時訓練曲線
在任意目錄下執行:$ tensorboard --logdir=/home/.../training打開返回的 http 鏈接查看 Loss 等曲線的實時變化情況。
  • 導出 .pb 文件用於推斷
模型訓練完後,生成的 .ckpt 文件已經可以調用進行目標檢測。也可以將 .ckpt 文件轉化為 .pb 文件用於推斷。在 TensorFlow models/research/object_detection 目錄下的終端執行$ python3 export_inference_graph.py --input_type image_tensor  
--pipeline_config_path /home/.../training/pipeline.config
--trained_checkpoint_prefix /home/.../training/model.ckpt-200000
--output_directory /home/.../training/output_inference_graph
執行上述代碼之後會在 /home/.../training 文件夾內看到新的文件夾 output_inference_graph,裡面存儲著訓練好的最終模型,如直接調用的用於推斷的文件:frozen_inference_graph.pb。其中命令中 model.ckpt-200000 表示訓練 200000 生成的模型,實際執行上述代碼時要修改為自己訓練多少次後生成的模型。其它路徑和文件(夾)名稱也由自己任意指定。
  • 調用訓練好的模型進行目標檢測
調用 frozen_inference_graph.pb 進行目標檢測請參考 TensorFlow models/research/object_detection 文件夾下的 object_detection_tutorial.ipynb 。但該文件只針對單張圖像,對多張圖像不友好,因為每檢測一張圖像都要重新打開一個會話(語句 with tf.Session() as sess每張圖像執行一次),而這是非常耗時的操作。可以改成如下的形式:
這樣改動之後,有好處也有壞處,好處是處理視頻或很多圖像時只生成一次會話節省時間,而且從原文件中去掉了語句:sys.path.append("..")
from object_detection.utils import ops as utils_ops
from utils import label_map_util
from utils import visualization_utils as vis_util
使得在任意目錄下都可以執行。
壞處是:上述代碼沒有使用 label_map_util 和 vis_util 等這些 object_detection 伴隨的模塊,使得檢測結果顯示的時候只能自己利用 OpenCV 來做,而存在一個較大的缺陷:不能顯示檢測出的目標的類別名稱。

…… 進行中

https://medium.com/ran-ai-deep-learning,ran1988mail@gmail.com

網誌所有文章總目錄個人簡歷 (Personal Resume)

--

--

Ran
Ran ( AI Deep Learning )

Senior Electronic R&D Manager。(DL Algorithm、software and hardware),ran1988mail@gmail.com,https://medium.com/ran-ai-deep-learning