Flask API for 訓練好的預測模型


需求描述:想知道怎麼讓人接到自己訓練好的預測模型,使得新資料進來時能夠直接即時做預測。


之前只碰過 Django,一直耳聞 Flask 能夠提供更輕量的 Web Service,但一直沒有實際試過;除此之外,也一直很好奇該怎麼讓他人接到自己訓練好的模型來做即時預測。剛好最近看到了這兩篇文章:A Flask API for serving scikit-learn modelsDeploying a Machine Learning Model as a REST API,於是決定捲起袖子來簡單測試看看。

模型準備

這次重點不在資料集或是模型本身,因此決定以最簡單的鐵達尼號資料集來做測試,也學參考文章的作者一樣只看三個變數:Age, Sex, Embarked。好,那麼首先就把資料讀進來:

import pandas as pd
df = pd.read_csv('train.csv')

接著把類別型變數編碼:

df_ohe = pd.get_dummies(df[['Age', 'Sex','Embarked', 'Survived' ]], columns=['Sex','Embarked'], dummy_na=True)

然後隨便補空值(真的隨便,勿學):

for col in df_ohe:
df_ohe[col].fillna(0, inplace=True)

接著準備好變數跟預測目標:

dependent_variable = 'Survived'
x = df_ohe[df_ohe.columns.difference([dependent_variable])]
y = df_ohe[dependent_variable]

叫出隨機森林並訓練模型:

from sklearn.ensemble import RandomForestClassifier as rf
clf = rf()
clf.fit(x, y)

把訓練好的模型,以及模型所需的欄位名稱存起來:

from sklearn.externals import joblib
joblib.dump(clf, 'model.pkl')
model_columns = list(x.columns)
joblib.dump(model_columns, 'model_columns.pkl')

Flask API

接著來把我們的預測模型包成 Rest API,首先把需要的套件 import 進來,並初始化我們的 APP:

from flask import Flask, jsonify, request
from sklearn.externals import joblib
import pandas as pd
app = Flask(__name__)

然後是簡單的資料處理跟模型預測的 function,其中取 request 裡面的值讓我 suffer 了一下,好險有這位大神的解答,才順利取到要的東西:

@app.route('/predict', methods=['POST'])
def predict():
# 取得傳入的參數並轉成 DataFrame
json_ = request.form.to_dict()
query_df = pd.DataFrame([json_])
# 簡單的資料處理(編碼類別型變數、確認欄位、補值等等)
query = pd.get_dummies(query_df)
model_columns = joblib.load('model_columns.pkl')
query = query.reindex(columns=model_columns, fill_value=0)
for col in model_columns:
if col not in query.columns:
query[col] = 0
# 預測
prediction = clf.predict(query).tolist()
return jsonify({'prediction': prediction})

main:

if __name__ == '__main__':
clf = joblib.load('model.pkl')
app.run(port=8080)

測試

開啟服務後就來簡單測試一下,首先先傳 17 歲從 C 港出發的女性資料給模型:

import requests
import pandas as pd
url = 'http://127.0.0.1:8080/predict'
data = {'Age':17, 'Embarked':'C', 'Sex':'female'}
response = requests.post(url, data=data)
response.json()

response 回傳:{‘prediction’: [1]},模型預測會存活。

再來一個 77 歲從 S 港出發的男性:

url = 'http://127.0.0.1:8080/predict'
data = {'Age':77, 'Embarked':'S', 'Sex':'male'}
response = requests.post(url, data=data)
response.json()

response 回傳:{‘prediction’: [0]},模型預測為無法存活。

以上,為簡單的測試。