如何在 Tensorflow 中使用 Dataset

任書瑋
Data Scientists Playground
6 min readDec 11, 2019

最早直接使用feed_dict的方式來輸入資料, 以前使用這個用法時都直接把所有資料儲存在記憶體, 大量的記憶體的使用量可能是最直觀的壞處,其他還有速度慢等等, 好處是方便和可以直接用print來debug

以下紀錄要在 Tensorflow 1.xx版用 Dataset的例子

任務為影像辨識應用於驗證碼上, 分別有訓練集與驗證集兩個資料夾, 資料夾內部為固定大小的圖片檔, 檔名就是驗證碼, 驗證碼的長度介於3~5個字元, 驗證碼的範圍a~z, 以下是範例

先把圖片的路徑和對應的label字串儲存在list裡面, 之後會用到

def _readfile(self):
imgs = os.listdir(folderPath)
for img in imgs:
label = img.split(".")[0]
labels.append(label)
imgPaths.append(self.folderPath + '/' + img)
numberImgs = len(imgs)

接下來定義parse_function, 在這裡要做三件事, 數字0保留給之後padding使用, 這裏我使用tf.py_func處理字串

1.依圖片路徑讀取圖片並做處理

2.將label字串轉成數字, label2int(labelString)

3.將label字串刪除最後並平移(abc -> #ab), 最後轉換成數字, 用意是在deocoder的teacher forcing 輸入, #代表decoder最一開始的輸入, 像是start token, label2intShift(labelString)

def _parse_function_train(self, filename, label):
def label2int(labelString):
labelString = str(labelString, encoding="utf-8")
a = [ [cfg.CHARSET.index(ch) + 1 for ch in labelString] ]
return a
def label2intShift(labelString):
labelString = str(labelString, encoding="utf-8")
labelStringShift = "#" + labelString[:-1]
b = [[cfg.CHARSET.index(ch)+1 for ch in labelStringShift] ]
return b
img_string = tf.read_file(filename)
img_decoded = tf.image.decode_png(img_string, channels=3)
img_resized = tf.image.resize_images(img_decoded, [cfg.width, cfg.height])
label2int_ = tf.py_func(label2int, [label], tf.int64)
label2intShift_ = tf.py_func(label2intShift, [label], tf.int64)
return {"img":img_resized / 256, "label2int":label2int_, "label2intShift":label2intShift_}

定義placeholder, 給dataset使用, 這裡首先要注意的是使用padded_batch來處理label長度在一個batch裡有長有短的問題, 這裡使用None來補到那個batch最大長度就好, 其中key要跟上面的parse_function回傳值對應

epoch = tf.placeholder(tf.int64)
batch_size = tf.placeholder(tf.int64)
x = tf.placeholder(dtypes.string, shape=[None])
y = tf.placeholder(dtypes.string, shape=[None])
dataset = tf.data.Dataset.from_tensor_slices((x ,y))
dataset = self.dataset.map(self._parse_function_train)
dataset = dataset.shuffle(buffer_size).padded_batch(batch_size, padded_shapes={"img":[160, 60, 3], "label2int":[None], "label2intShift":[None]}).repeat(self.epoch)

接下來定義iterator, 注意的是使用make_initializable_iterator(), 這樣就能搭配前面的placeholder, 可以在執行initializer operation 後更改不同的batch size或切換data 來源

iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

執行initializer operation完後就完成了

sess.run(iterator.initializer, feed_dict={x:imgPaths, y:labels, batch_size:batch_size, epoch:epoch})

Reference

--

--