使用 SimNet 的 Quora 語意匹配

Chao-Hsuan Ke
小小實驗室
Published in
7 min readJan 15, 2020

使用 SimNet 的針對 Quora 資料集 (pointwise 格式) 進行語意匹配的判斷,可以參考官方兩篇文章 ,但因為過程中多少遇到了一些問題,因此重新把完整的過程再寫一次。

以及

這邊是用 docker 先建立起一個 ubuntu 環境,因此所有的安裝過程皆可以在 ubuntu 機器上重構一次 。由於這篇文章是針對 Quora 資料集進行,所以官方原本的範例資料就先移除掉,只保留下 Quora 的語意匹配資料。

環境

Ubuntu 18.04.3 LTS

Python 2.7.17

Tensorflow 1.14.0

安裝

下載 AnyQ

將下載回的程式碼放在 /tmp 資料夾中,完整的 SimNet 程式碼位置我們預設放在/tmp/AnyQ/tools/simnet/

SimNet 的結構如下:

/tmp/AnyQ/tools/simnet
|-tf
|- date // 資料集
|- examples // 配置文件
|- layers // 網路操作層
|- losses // 損失函數
|- nets // 網路結構
|- tools // 資料存取與評估
|- util // 工具區 - 資料轉換
|- model // model 輸出位置
|- pointwise

若從官方版下載程式後不會有 model/pointwise 的資料夾,需要自行建立。

下載 Quora 資料集

到 Kaggle 官方網站下載,請點擊下方連結自行下載,這邊只需要下載 train.csv 的部分即可。

資料處理

剛才下載回來的 train.csv 並不能做 SimNet 的輸入數據,所以須執行一些前處理。須先把存文字轉變成 dictionary 儲存的方式,之後在將每個文字的 id 利用字嵌入等處理方式轉成向量數值,而 word embedding 是選擇詞袋法(Bag-of-word)的方式。

17 15 10 677 90 13813 75 1248 5328 103 35 169 141 146 17 15 10 677 90 13813 75 1248 5328 103 35 169 1

這邊就省略了如何做轉置的解釋,程式碼內詳細的說明可以參考官方網站上的說明,我們可以直接使用編輯好的 quora.py 這個檔案進行處理。

此時產生出的 train_.tsv 與 test_.tsv 位置在 :

simnet
|-tf
|- date
|- train_.tsv // 訓練集資料
|- test_.tsv // 測試集資料

建立一個名為 run_preprocess.sh 的 bash 檔將 quora.py 產生出的字典檔轉換為詞嵌入格式。

此 shell 會將 train_.tsv 與 test_.tsv 分別轉換為 convert_train_ 與 convert_test_ 兩個檔案。

當完成資料轉換後,就可以開始餵入進行模型的建立。但為了配合 Quora 的資料大小,在設定檔內需要修改相關變數。

變數修改

這邊先用 CNN 進行測試來建立模型,所以就直接採用 /tmp/AnyQ/tools/simnet/train/tf/examples/內的 cnn_pointwise.json 設定檔。這個檔案內有些變數需要修正 :

  • data_size = 323273 #因為 train_.tsv 有 323273 筆樣本
  • vocabulary_size = 1000000
  • batch_size = 800
  • num_epochs = 1
  • print_iter = 10
  • train_file = data/convert_train_
  • test_file = data/convert_test_

以及 tmp/AnyQ/tools/simnet/train/tf/目錄下的 tf_simnet.py

找到def predict(conf_dict),修改如下程式碼:

conf_dict.update({'num_epochs': '1', 'batch_size': '1',
'shuffle': '0', 'train_file': conf_dict['test_file']})

將其修改成

conf_dict.update({'num_epochs': '1', 'batch_size': '400',
'shuffle': '0', 'train_file': conf_dict['test_file']})

增加訓練模型 bash file

當所有的設定檔完成後就可以開始建立模型。由於 SimNet 有 7種的網路可以選擇,為了方便之後區別與執行,可以將每一個模型的訓練與測試都分開來寫成 bash 檔。因為這個範例是使用 CNN 建置網路,因此我們建立一個 run_train_cnn.sh 如下 :

set -e # set -o errexit
set -u # set -o nounset
set -o pipefail

in_task_type='train'
in_task_conf='./examples/cnn_pointwise.json'
python tf_simnet.py \
--task $in_task_type \
--task_conf $in_task_conf

增加驗證模型 bash file

有了訓練模型的 bash 檔,也需要建立一個測試模型的 bash 檔命名為 run_predict_cnn.sh。同樣都在 tmp/AnyQ/tools/simnet/train/tf/ 下,內容如下 :

set -e # set -o errexit
set -u # set -o nounset
set -o pipefail
in_task_type='predict'
in_task_conf='./examples/cnn_pointwise.json'
python tf_simnet.py \
--task $in_task_type \
--task_conf $in_task_conf

執行

執行 run_train_cnn.sh 後會開始建置網路模型,而模型資料會被存在 tmp/AnyQ/tools/simnet/train/tf/model/pointwise 而執行結果如下 :

當模型建立完成後繼續執行 run_predict_cnn.sh 進行測試集資料的語意匹配,會得到一個正確率 Accuracy

以上就是用 SimNet 基於 CNN 進行語意匹配的測試,SimNet 總共有 7 個網路模型可以使用,可以在基於原本的範例檔修改變數後在測試看看誰的效果比較好。

正確率

以下是針對 Quora 資料集使用內建的 7 個網路模型 (變數皆用預設值) 進行訓練後執行測試集資料所得到的正確率。得到最高正確率的是 pyramid,而需要最久訓練時間的是 mmdnn。

--

--

Chao-Hsuan Ke
小小實驗室

永遠熱愛自己的工作,總是找一堆事把自己的時間塞滿。喜歡接觸不同領域,像是 歷史、文化、金融和公共政策 等議題,期許著自己會什麼就分享什麼。