動手做一個機器學習 API

Tsai Yi Lin
Taiwan AI Academy
Published in
10 min readJan 15, 2020

開發的最後一哩路部署應用。當我們訓練好一個模型後,下一步該怎麼做呢?你可以發揮一些巧思進行前後端整合,例如製作一個APP或是以網頁型態,有個操作介面視覺化的呈現。本篇文章以手寫數字為例,使用資料降維以及機器學習的方法來訓練一個預測模型,最後在使用Python Flask 框架搭建一個 RESTful API。

何謂非監督式學習?

非監督式學習只給定特徵機器會想辦法會從中找出規律。在非監督式學習中資料的來源就非常的重要,通常我們會先做探索式資料分析(Exploratory Data Analysis,簡稱EDA)來做資料的視覺分析。做EDA有他的好處,首先你可以對資料有初步的認識,並且可以透過資料視覺化來觀察分布狀態。例如資料常態分佈或是離散分佈而做出相對應的資料清理與正規化等。監督式的學習最常見的方法就是集群分析(Cluster Analysis),目標是根據特徵(feature)就把資料樣本分為幾群。

降維 (Dimension Reduction)

一般資料常見的表示方法有1維(數線)、2維(XY平面)、3維(XYZ立體)。大於三維的數據就難以視覺化呈現,那麼我們該如何去表示高維度的資料但又不能壓縮原本資料間彼此的關連性呢?這時降維就能幫助你了!降維顧名思義,就是原本的Data處於在一個比較高的維度作標上,我們希望找到一個低維度的作標來描述它,但又不能失去Data本身的特質。

為什麼要降維?

  • 如果我們能把一些資料做壓縮,又能夠保持資料原來的的性質。因此我們可以用比較少的空間,或是計算時用比較少的資源,就可以得到跟沒有做資料壓縮之前得到相似的結果。
  • 此外做資料降維可以幫助資料視覺化,二維可以用xy平面圖表示,三維可以用xyz立體圖作表示,而大於三維的空間難以視覺化做呈現。

兩種常見的降維演算法

  • Principal component analysis (PCA)
  • T-Distributed Stochastic Neighbor Embedding (t-SNE)

PCA & t-SNE 比較

PCA和t-SNE是兩個不同降維的方法,PCA的優點在於簡單若新的點要映射時直接代入公式即可得出降維後的點。若t-SNE有新的點近來時我們沒有去計算新的點和舊的點之間的關係因此我們無法將新的點投影下去。t-SNE的優點是可以保留原本高維距離較遠的點降維後依然保持遠的距離,因此這些群降維後依然保持群的特性。

實際範例: MNIST手寫數字辨識

MNIST數據庫是一個手寫數字的大型數據庫,是由Yann LeCun所蒐集,這位大神同時也是Convulution Nueral Networks(卷積神經網絡)的創始人,因此享有卷積網絡之父的美稱。MNIST資料集是由60,000筆訓練資料、10,000筆測試資料所組成,資料集裡的每一筆資料皆由images(數字的影像)與labels(答案)所組成。今天我們不透過深度學習(DCNN)的方法來實作,而是搭配本文先前所介紹的方法降維來分析與分群分類。

1) 步驟ㄧ:載入mnist datasets

首先我們載入keras所提供的mnist datasets,並將28*28像素的照片轉換成一維Numpy array。緊接著將所有資料正規化(除以255),此方法目的是減少他的計算量。除以255這個運算在圖像處理中很常見,因為顏色的值實際上是一個規格化的數,如果以字節整數來表示的話,範圍是0-255之間。

2) 步驟二:使用t-SNE降維

透過降維的方式我們可以將784維的資料轉換成2維並投射在平面上,並且可以使用平面座標圖來觀察分佈狀況。還記得為什麼我們要使用t-SNE嗎?因為t-SNE允許非線性的轉換,此外t-SNE使用了更複雜的公式來表達高維與低維之間的關係。因此在這種0~9有十個分類的情況下可以確保彼此間的距離會被區隔該而不會重疊。

這裡我們這裡採用 sklearn 所提供的 TSNE 訓練模型,在官方的API中提供了三個我們可以調參的參數。第一個是 n_components 即為降維之後的維度因為我們要投射到平面座標上因此維度設為2。第二個是 random_state 為最佳化過程中考慮鄰近點的多寡,這項參數我們設定42(官方建議5-50之間)。最後一個是 n_iter 迭代次數,這邊我們先設定5000代。下圖即為降維後的結果,我們使用了keras所提供的mnist datasets中的訓練資料共六萬筆來做t-SNE降維。從結果圖中我們可以很清楚的將這六萬張手寫辨識圖片清楚的分類與投影到XY平面座標上。因為mnist的資料庫當中已經將每一個訓練圖片都坐上標籤(標準答案),因此我們降維後直接透過 matplotlib 將降維後的XY與label標籤繪圖呈現,若在沒有標籤答案的狀況下我們可以使用分群演算法來做集群式的分析與分類。常見的非群演算法有 K-MeansHierarchical Clustering。這也是非監督式學習常見的演算法,因為文章篇幅關係這次就不探討分群改留到下次討論。

3) 步驟三:使用XGBoost學習一個預測降維後的結果

為什麼我們需要訓練一個模型來模擬預測t-SNE降為後的輸出呢?因為大筆資料的降維需要浪費很多計算資源與時間,第二步驟光是降維六萬筆資料就耗時6小時(CPU版本)。因為在t-SNE降維是要根據所有的資料去乎相比較計算出來的。因此如果你需要降維新的一筆資料就必須再從頭來過也就是要再花6小時跑降維。因此這裡有個方法我們可以利用監督式學習的方式,目標學出一個模型可以直接將784維的(input)資料,預測t-SNE降維後的2D資料。這裡使用 XGBoost 來實作,XGBoost 全名為eXtreme Gradient Boosting(極限梯度提升)。是一種提升樹模型,也是屬於Ensemble learning(集成學習)的方法,它是將許多樹模型集成在一起,形成一個很強的模型。

這一步驟就是使用 XGBoost 回歸方式來訓練一個模型,這邊的輸入為一張28*28像素的照片轉換成一維Numpy array也就是784個特徵的一維資料。而輸出為二維的數據,即為預測該張圖經過t-SNE降維後的結果。下圖是經過 XGBoost 回歸所訓練出來的模型,左邊為內部測試(training data),右邊為外部測試(testing data)。我們可以從圖中看出來我們訓練的模型在訓練集和測試集都表象的不錯,還是能從平面圖中快速的將這些照片分類並區隔出來。

# 訓練結果
- Score: 0.99981
- MAE: 0.8489442
- MSE: 1.3104383
- RMSE: 1.1445

4) 步驟四:使用XGBoost訓練一個手寫數字分類器

剛剛透過XGBoost的回歸學習訓練出一個模擬t-SNE的降維結果,接著我們要訓練一個XGBoost的分類器來預測手寫辨識的圖片。

# 訓練結果
- 內部測試: Accuracy: 95.15%
- 外部測試: Accuracy: 95.19%

5) 步驟五:儲存模型

我們訓練好了一個模型以後需要保存並讓下一次直接預測。常見的兩種保存Model的模塊有 pickle 與 joblib,範例使用 pickle 。由於 pickle 儲存模型後容量可能會有好幾百MB因此建議可以透過 gzip 來壓縮模型並儲存。

使用 Python Flask 建置一個手寫數字辨識 API

兩個模型都訓練完成後,我們就可以實際整合到專案中展示。展示的方式有很多種像是網頁或是手機APP,通常這邊我們都會需要建立一個後端API來負責處理這一段的任務。Flask是一個使用Python編寫的輕量級Web應用框架。

系統架構與流程

  1. 網頁前端收到手寫圖片轉換成base64編碼格式傳給API。
  2. 後端API接收到訊息將base64轉換成cv2 ndarray
  3. BGR→GRAY變成28*28二維矩陣,並做正規化除以255
  4. 28*28二維矩陣變形成一維矩陣(784)
  5. 進行模型降維和預測

部署

俗稱開發的最後一哩路 ‘部署應用’。Heroku是一個支援多種程式語言的雲平台即服務。我們可以透過此平台建立一個免費帳號,免費帳號的雖然有限制但是足夠我們小型專案使用。免費帳號只能提供五個專案建立,此外30分鐘沒有使用會進入睡眠狀態,之後要開啟需要等待一些時間才能運作。最後還有500MB的儲存空間限制,這空間對開發者開發一些小專案來說夠用了。

  • Python語言部署到Heroku平台
  • 在Heroku建立專案
  • 使用Git自動部署並與GitHub連動
  • 在Heroku專案內新增 buildpack

前後端整合

Demo

Python Flask 新手包

我在這邊提供一個我之前建立的 Python Flask 開發懶人包,簡易的MVC框架。開法者可以快速的進行 Python Flask API 的開發。

API
┌── config.py // 環境變數設置

├── app
│ ├── controllers // 處理控制流程和回應
│ ├── modules // 後端資料庫進行運作or載入機器學習模型
│ └── __init__.py // 各路徑的設定點

└── run.py // 程式進入點

GitHub: flask-api-starter-kit

Gunicorn 部署介紹 : https://www.youtube.com/watch?v=rEWtDVAHb4U&t=268s

--

--