Flaskを使って機械学習モデルのロード時間を短縮する
機械学習モデルが複雑になってくると、モデルのロードに大量の時間がかかるようになってきます。ローカルサーバーで常時モデルを立ち上げた状態にして、プログラム起動時に入力がローカルサーバーを経由するようにすればモデルのロード時間を待たなくてすみます。
今回はわかりやすいようにモデルとしては軽めの犬猫画像識別モデルで作ってみます。
GitHubにコードをまとめておきます。
前準備
犬猫認識モデルの作成、学習方法、モデルの保存方法は以前の記事「python初心者が作る、犬猫認識AI」を参照してください。
またお手持ちの環境にflaskが入っていない場合は以下のようにインストールしておいてください。
pip install flask
ローカルサーバーの構築
###server.py###from flask import Flask, request, jsonify, abort
from keras.preprocessing import image
import numpy as np
from keras.models import load_model
app = Flask(__name__)
import tensorflow as tf
graph = tf.get_default_graph()# /api/predict にPOSTリクエストされたら予測値を返す関数@app.route('/api/predict', methods=["POST"])
def predict():
global graph
with graph.as_default():
try:
# run.pyからJSONの情報を受け取る
image_path = request.json["image_path"] # 画像を読み込んで配列へ
img = image.load_img(image_path, target_size=(150, 150))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = x / 255.0 # 予測
result_predict = model.predict(x) # 予測をJSONへ
response = {"status": "OK", "prediction": float(result_predict[0])} # JSONの情報をrun.pyに返す
return jsonify(response) except Exception as e:
print(e)
abort(400)# abort(400)した時のレスポンス
@app.errorhandler(400)
def error_handler(error):
response = {"status": "Error", "message": "Invalid Parameters"}
return jsonify(response)
if __name__ == "__main__":
# 立ち上げたままにしたいモデルのロード
model = load_model('dog_cat.h5')
# ローカルサーバーの起動
app.run(host="localhost")
app.routeの中で特にエラー無く処理が完了した場合は
"status":"OK","prediction":float(result_predict[0])
が戻されます。
例外処理で弾かれた場合、app.errorhandlerを通り
"status":"Error”,"message":”Invalid Parameters”
が戻ります。
import tensorflow as tf
graph = tf.get_default_graph()
global graph
with graph.as_default():
の部分はこのサイトとこのサイトを参考にしました。
完全に理解できていませんが、これがないとうまく動かないようです。
model=load_model("dog_cat.h5")
はapp.routeの上に書いても動きますが、モデルを複数ロードする場合はなぜかif __name__ == “__main__”:の下にないといろいろエラーがおきました。
予測した値 result_predictはfloat()で囲っておかないと、値を返すときに配列はJSONで返せませんと怒られます。
ここのエラーを除くのに一番時間がかかりました。
参考にしたサイトはここです。
python server.py
でローカルサーバーが立ち上がります。
テストの準備
###run.py###import requests
import json
import sys# 画像パスは外からあたえてみます
IMAGE_PATH = ' '.join(sys.argv[1:])# 〇〇〇〇〇の部分はローカルサーバーのアドレスに置き換える
URL = "http://〇〇〇〇〇/api/predict"
DATA = {"image_path": IMAGE_PATH}# DATAの情報をローカルサーバーに送信
response = requests.post(URL, json=DATA)# 返ってきたJSON情報を出力してみる
print(response.text)# JSONをdict型に変換する
result = json.loads(response.text)# [" "]で情報を取り出す
print(result["prediction"])# 犬猫判定
if result["prediction"] < 0.5:
print("This is cat")
if result["prediction"] > 0.5:
print("This is dog")
URL=“http://〇〇〇〇〇/api/predict"の”〇〇〇〇〇”の部分は先ほど立ち上げたローカルサーバーのアドレスにしてください。
ローカルサーバーのアドレスがhttp://localhost:5000/の場合の例
URL = "http://localhost:5000/api/predict"
テストの実施
python run.py test_image_path
軽いモデルだとローカルサーバー経由は逆に時間がかかるかもしれませんが、モデルが大きいもので試してもらえればありがたみがわかると思います。
少し余談
import unittest でいろいろ便利なテストケースを作成できます。
unittestの使い方はここを参照してください。以下はその例
# run.pyimport unittest
import requests
import json
class APITest(unittest.TestCase):
URL = "http://〇〇〇〇〇/api/predict"
DATA = {"image_path": "test_path"}
def test_normal_input(self):
response = requests.post(self.URL, json=self.DATA)
print(response.text)
result = json.loads(response.text)
print(result["prediction"])
# ステータスコードが201かどうか
self.assertEqual(response.status_code, 201)
# statusはOKかどうか
self.assertEqual(result["status"], "OK")
# 非負の予測値があるかどうか
self.assertTrue(0 <= result["prediction"])
if result["prediction"] < 0.5:
print("This is cat")
if result["prediction"] > 0.5:
print("This is dog")
if __name__ == "__main__":
unittest.main()# server.pyfrom flask import Flask, request, jsonify, abort
from keras.preprocessing import image
import numpy as np
from keras.models import load_model
app = Flask(__name__)
import tensorflow as tf
graph = tf.get_default_graph()
@app.route('/api/predict', methods=["POST"])
def predict():
global graph
with graph.as_default():
try:
image_path = request.json["image_path"]
img = image.load_img(image_path, target_size=(150, 150))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = x / 255.0
result_predict = model.predict(x)
response = {"status": "OK", "prediction": float(result_predict[0])}
return jsonify(response), 201
except Exception as e:
print(e)
abort(400)
@app.errorhandler(400)
def error_handler(error):
response = {"status": "Error", "message": "Invalid Parameters"}
return jsonify(response), error.code
if __name__ == "__main__":
model = load_model('dog_cat.h5')
app.run(host="localhost")