MLflow custom model flavor
การใช้งาน MLflow กับ Custom model flavor
เพื่อความเข้าใจในบทความนี้มากขึ้นอยากให้กลับไปอ่าน เก็บ Model ของคุณไว้ใช้งานด้วย MLflow
ในบทความนี้เราจะใช้ feature ของ MLflow ในการจัดการกับ custom model flavor ซึ่งจะ implement การ save model, จัดเก็บ ML model metric และการ load model โดยใช้ wrapper class ที่ inherited มาจาก mlflow.pyfunc.PythonModel เพื่อจัดการกับ output ของ model prediction ให้ออกมาตามที่เราต้องการ
การ train model ครั้งนี้จะใช้ Classify images of clothing เหมือนบทความ เก็บ Model ของคุณไว้ใช้งานด้วย MLflow แต่จะมีการ save file เพิ่มเติมและปรับ predict function เพื่อให้มีชื่อ class ออกมาด้วยในการ predict
# code structure
|- wrapper
|- model_wrapper.py
|- custom_mlflow_tracking.ipynb
เขียน wrapper function เพื่อกำหนด output ในการ predict
# model_wrapper.py
import mlflow
import tensorflow as tf
class ModelWrapper(mlflow.pyfunc.PythonModel):
"""
Class to train and use scratch model
"""
def load_context(self, context):
"""This method is called when loading an MLflow model with pyfunc.load_model(), as soon as the Python Model is constructed.
Args:
context: MLflow context where the model artifact is stored.
"""
## load model implement
import json
with open(context.artifacts["class_names_path"], "r") as fp:
self.class_names = json.load(fp)
self.model = tf.keras.models.load_model(context.artifacts["artifact_path"])
def predict(self, context, model_input):
"""Evaluates a pyfunc-compatible input and produces a pyfunc-compatible output. For more information about the pyfunc input/output API, see the Inference API.
Args:
context ([type]): MLflow context where the model artifact is stored.
model_input ([type]): the input data to fit into the model.
Returns:
[type]: the loaded model artifact.
"""
# return your model prediction
import numpy as np
probability_model = tf.keras.Sequential([self.model,
tf.keras.layers.Softmax()])
prob_predict_result = probability_model.predict(model_input)
predict_class_name = [self.class_names[np.argmax(item)] for item in prob_predict_result]
predict_prob_max = [np.max(item) for item in prob_predict_result]
result = [(class_name,prob) for class_name,prob in zip(predict_class_name,predict_prob_max)]
return result
ใน class ModelWrapper เราจะ Override 2 method คือ
- load_context : เป็น method ที่ใช้ load model และ file ต่างๆเพื่อนำไปใช้ในการ predict
- รับ model file และ class name file จาก context.artifacts แล้วอ่านไฟล์ - predict : เป็น method ที่ implement ตาม predict function ของแต่ละ ML framework
- ทำ output ให้เป็น probability โดยใช้ softmax function
- map output prediction กับ class name
- คืนค่า result ออกเป็น class name และ probability
ส่วนการเขียน code เพื่อ tracking model ทำดังนี้
from datetime import datetime
import json
import mlflow
from mlflow.models.signature import infer_signature
import tensorflow as tf
TRACKING_URI = 'http://localhost:5000'
EXPERIMENT_ID = 326755177053281989
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
wrapper_model = ModelWrapper()
mlflow.set_tracking_uri(TRACKING_URI)
run_tag = {'model':'clothing_classification','framework':'tensorflow'}
now = datetime.now()
run_name = f"run_name_{now.strftime('%m%d%Y%H%M%S')}"
with mlflow.start_run(experiment_id=EXPERIMENT_ID,tags=run_tag) as run_mlflow:
mlflow.set_tag("mlflow.runName", run_name)
# train model
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=10)
# evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
# log metric
metrics = {"test_loss":test_loss,
"test_accuracy":test_acc
}
mlflow.log_metrics(metrics)
# log model summary as a text file
model_summary_path = 'model_summary.txt'
stringlist = []
model.summary(print_fn=lambda x: stringlist.append(x))
model_summary = "\n".join(stringlist)
mlflow.log_text(model_summary,model_summary_path)
# save classname file
class_names_path = "class_names.txt"
with open(class_names_path, "w") as fp:
json.dump(class_names, fp)
# manual save model file
model_path = 'save_model'
model.save(model_path)
# input-output signature
signature = infer_signature(train_images)
# serve model
artifacts = {"class_names_path": class_names_path,'artifact_path':model_path}
mlflow.pyfunc.log_model(
artifact_path='model',l
artifacts=artifacts,
#code_path = wrapper function dir name เพื่อ เก็บ model_wrapper.py ไว้บน mlflow
code_path=['wrapper'],
python_model=wrapper_model,
signature=signature
)
train model โดยใช้ MLflow เพื่อ tracking
mlflow.set_tracking_uri()
ติดต่อผ่าน http://localhost:5000- ใช้
mlflow.log_params()
เพื่อเก็บ parameters จาก model - ใช้
mlflow.log_metrics()
เพื่อเก็บ output metric จาก model - save text file ของ class name
- save model file โดยใช้ built-in function ของ keras
- เขียน signature model เพื่อเป็นตัวอย่าง input/output model ได้โดยใช้ module infer_signature
- ใช้
mlflow.pyfunc.log_model()
เพื่อ tracking model on MLflow
- artifact_path: ชื่อ folder เก็บ model file
- artifacts: json เก็บ key file ที่ต้องการ save บน artifact storage (ในที่นี้ มี 2 path file ที่ต้องการ serve ขึ้นคือ class name text file และ saved model file)
- python_model: python class model wrapper
- signature: model signature(กำหนด input/output format)
หลังจาก train model แล้วจะได้ model ที่เก็บบน MLflow และไฟล์ต่างๆ ดังภาพที่ 1
เมื่อเราอยาก load model มาใช้งาน ใน custom model flavor เราเพียงใช้คำสั่ง mlflow.pyfunc.load_model(model_uri)
เพื่อเรียกใช้ model
YOUR_RUN_ID = 'xxx'
logged_model = f'runs:/{YOUR_RUN_ID}/model'
loaded_model = mlflow.pyfunc.load_model(logged_model)
loaded_model.predict(test_images[0:5])
จะเห็นว่าการทำ custom model flavor เราจะเก็บการทดลองจาก framework ไหนก็ได้ตามที่เราต้องการ นอกจากนี้ยังสามารถเขียน wrapper function เพื่อกำหนด output เพื่อทำ post-process จาก model ของเราได้อีกด้วย
ตัวอย่าง code ดูได้ที่