Fine-tuning Flan-T5 Base and online deployment in Vertex AI
Google has released the checkpoints of several open-source language models, including BERT, T5 and UL2. Among the most interesting in terms of performance and results is Flan-T5, a variant of T5 that generalises better and outperforms T5 on many NLP tasks.
This post shows how to fine-tune a Flan-T5-Base model for the SAMSum dataset (summary of conversations in English) using Vertex AI. In particular, you will use Vertex AI Training with a 1xA100 GPU (40 GB HBM) for fine-tuning, and Vertex AI Prediction for online predictions.
The dataset: SAMSum
The SAMSum dataset contains about 16k messenger-like conversations with summaries. The conversations were created and written down by linguists fluent in English. This use case is similar, for example, to the conversation summaries feature available in Google Chat.
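Each SAMSum record has three fields: id, dialogue and summary. Before tokenization, a record has to be turned into an input/target pair. The sketch below assumes a "summarize: " task prefix, a common choice in Flan-T5 summarization tutorials; the repo's trainer may format inputs differently, and the record is made up for illustration (its dialogue is an excerpt of the one used later in this post).

```python
from typing import Tuple

# Assumption: the model input is the dialogue with a "summarize: " task prefix.
PREFIX = "summarize: "


def to_training_pair(record: dict) -> Tuple[str, str]:
    """Turn a SAMSum record into (model input, target summary)."""
    source = PREFIX + record["dialogue"].strip()
    target = record["summary"].strip()
    return source, target


# Hypothetical record with the same fields as SAMSum (id, dialogue, summary).
record = {
    "id": "example-0",
    "dialogue": "Greg: Hi Mum, how's the cat doing?\nPatti: I just rang the vets, she's fine!",
    "summary": "Patti's cat is fine.",
}
source, target = to_training_pair(record)
print(source)
print(target)
```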
The model: Flan-T5 Base
Flan-T5 is a variant of T5 that outperforms it on a wide variety of tasks. It is multilingual and uses instruction fine-tuning, which generally improves the performance and usability of pretrained language models such as T5.
Flan-T5 has public checkpoints in different sizes. This code sample uses the google/flan-t5-base version.
Fine-tuning
Using libraries from Hugging Face, you will fine-tune a Flan-T5-Base model on the SAMSum dataset of English conversations. You will start from the google/flan-t5-base checkpoint, which is fine-tuned and then uploaded to Vertex AI Model Registry.
The model is fine-tuned on Vertex AI with one NVIDIA A100 GPU (40 GB HBM). The code launches a training pipeline, a type of Vertex AI job, which runs three steps in sequence: create a Vertex AI managed dataset (skipped here, because the dataset is downloaded from Hugging Face), run the workload in Vertex AI Training (here, the fine-tuning), and upload the model to Vertex AI Model Registry.
The fine-tuning should take around 45 minutes to complete.
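```python
from google.cloud import aiplatform

# PROJECT_ID, REGION and BUCKET_URI are placeholders for your own values;
# CustomTrainingJob needs a staging bucket for the packaged training script.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

# Training script, extra pip packages, and prebuilt PyTorch GPU containers
# for training and for serving the uploaded model.
job = aiplatform.CustomTrainingJob(
    display_name="flan_t5_base_finetuning_gpu_tensorboard",
    script_path="flant5base_trainer.py",
    requirements=["py7zr==0.20.4",
                  "nltk==3.7",
                  "evaluate==0.4.0",
                  "rouge_score==0.1.2",
                  "transformers==4.25.1",
                  "tensorboard==2.11.2",
                  "datasets==2.9.0",
                  "google-cloud-storage==2.7.0"],
    container_uri="europe-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-10:latest",
    model_serving_container_image_uri="europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-10:latest",
)

# Run on a single a2-highgpu-1g machine (1x NVIDIA A100 40 GB).
# Logging to Vertex AI TensorBoard requires a service account.
model = job.run(
    model_display_name='flan-t5-base-finetuning-gpu-tensorboard',
    replica_count=1,
    service_account=SERVICE_ACCOUNT,
    tensorboard=TENSORBOARD_RESOURCE_NAME,
    machine_type="a2-highgpu-1g",
    accelerator_type="NVIDIA_TESLA_A100",
    accelerator_count=1,
)
```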
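The training script flant5base_trainer.py is not shown in this post. For job.run to upload a model, the script has to write its artifacts to the location Vertex AI passes in the AIP_MODEL_DIR environment variable (a gs:// URI). A minimal sketch of resolving that location, assuming the trainer writes through the Cloud Storage FUSE mount at /gcs/; the repo's trainer may instead upload with google-cloud-storage, which is in its requirements. The helper name local_model_dir and the local fallback directory are hypothetical.

```python
import os


def local_model_dir(default: str = "model-output-flan-t5-base") -> str:
    """Return the local path where the trainer should save model and tokenizer."""
    # Vertex AI sets this to a gs:// URI when the job has an output directory,
    # as training pipelines that upload a model do.
    uri = os.environ.get("AIP_MODEL_DIR")
    if not uri:
        # Running outside Vertex AI: save to a plain local directory.
        return default
    if uri.startswith("gs://"):
        # Cloud Storage FUSE exposes gs://BUCKET/PATH as /gcs/BUCKET/PATH.
        return "/gcs/" + uri[len("gs://"):]
    return uri


# Hypothetical use inside flant5base_trainer.py:
#   out_dir = local_model_dir()
#   trainer.save_model(out_dir)
#   tokenizer.save_pretrained(out_dir)
```

Saving the tokenizer next to the model matters later, because the prediction container needs both.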
Uvicorn
A custom container image is required for predictions, and Vertex AI requires such a container to run an HTTP server. Specifically, the container must listen and respond to liveness checks, health checks, and prediction requests.
This repo uses FastAPI and Uvicorn to implement the HTTP server, which listens for requests on 0.0.0.0. Uvicorn is an ASGI web server implementation for Python and currently supports HTTP/1.1 and WebSockets. There is also a Docker image with Uvicorn managed by Gunicorn for high-performance FastAPI web applications in Python 3.6+, with performance auto-tuning. A Uvicorn server is launched with:
```bash
uvicorn main:app --host 0.0.0.0 --port 8080
```
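The container's main.py is not reproduced in this post. To make the serving contract concrete, here is a dependency-free stand-in for the FastAPI app, with Python's http.server swapped in for FastAPI and Uvicorn: it answers health checks on /health and predictions on /predict, and listens on 0.0.0.0:8080. It reads the AIP_HTTP_PORT, AIP_HEALTH_ROUTE and AIP_PREDICT_ROUTE variables that Vertex AI passes to custom containers, falling back to this post's values. summarize() is a hypothetical placeholder for the Flan-T5 call, and the sketch assumes one dialogue string per instance, whereas the SDK call later in this post wraps each dialogue in a list.

```python
import json
import os
from http.server import BaseHTTPRequestHandler, HTTPServer

# Port and routes Vertex AI asks the container to serve; defaults match this post.
PORT = int(os.environ.get("AIP_HTTP_PORT", "8080"))
HEALTH_ROUTE = os.environ.get("AIP_HEALTH_ROUTE", "/health")
PREDICT_ROUTE = os.environ.get("AIP_PREDICT_ROUTE", "/predict")


def summarize(dialogue: str) -> str:
    """Hypothetical placeholder for Flan-T5: returns the dialogue's first line."""
    return dialogue.splitlines()[0] if dialogue else ""


def handle_predict(body: dict) -> dict:
    """Map a Vertex AI request {"instances": [...]} to {"predictions": [...]}."""
    return {"predictions": [summarize(text) for text in body.get("instances", [])]}


class Handler(BaseHTTPRequestHandler):
    def _send_json(self, status: int, payload: dict) -> None:
        data = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_GET(self) -> None:
        # Health checks arrive as GET requests on the health route.
        if self.path == HEALTH_ROUTE:
            self._send_json(200, {"status": "ok"})
        else:
            self._send_json(404, {"error": "not found"})

    def do_POST(self) -> None:
        # Prediction requests arrive as POST requests with a JSON body.
        if self.path != PREDICT_ROUTE:
            self._send_json(404, {"error": "not found"})
            return
        length = int(self.headers.get("Content-Length", 0))
        body = json.loads(self.rfile.read(length) or b"{}")
        self._send_json(200, handle_predict(body))


def serve() -> None:
    """Container entrypoint: listen on all interfaces, like the uvicorn command."""
    HTTPServer(("0.0.0.0", PORT), Handler).serve_forever()
```

Whatever framework serves the app, the routes and port here must match the values passed to Model.upload when the model is registered.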
Export model from Vertex AI Model Registry
After fine-tuning, the model is stored in Vertex AI Model Registry. Since you are going to embed the model in a custom Vertex AI Prediction container, you must export the model and copy it into the local predict/model-output-flan-t5-base directory (for example, with gcloud storage cp). The predict/ directory contains the Dockerfile used to build the custom prediction container image.
Before you build the image, the predict/model-output-flan-t5-base directory must contain files similar to the ones below. Note that both the model and the tokenizer must be included:
config.json
pytorch_model.bin
special_tokens_map.json
spiece.model
tokenizer.json
tokenizer_config.json
training_args.bin
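A missing tokenizer or weights file only surfaces once the container starts, so a quick check before the Docker build can save a round trip. A minimal sketch; the helper name missing_model_files is hypothetical, and the required list is limited to the model config, weights and core tokenizer files from the listing above (extend it to match your export).

```python
from pathlib import Path

# Files that must be present for the container to load model and tokenizer.
REQUIRED_FILES = [
    "config.json",
    "pytorch_model.bin",
    "special_tokens_map.json",
    "tokenizer_config.json",
]


def missing_model_files(model_dir: str, required=REQUIRED_FILES) -> list:
    """Return the required file names that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in required if not (root / name).is_file()]
```

For example, call missing_model_files("predict/model-output-flan-t5-base") before building the image and stop if the returned list is not empty.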
Build Custom Prediction Container image and upload model to Vertex AI Model Registry
Build the Docker image with Cloud Build and push it to Artifact Registry:
```bash
gcloud auth configure-docker europe-west4-docker.pkg.dev
gcloud builds submit --tag europe-west4-docker.pkg.dev/argolis-rafaelsanchez-ml-dev/ml-pipelines-repo/finetuning_flan_t5_base
```
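The image name above follows Artifact Registry's LOCATION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE pattern, and the same string is reused as the deployment image in the next step. A small sketch that composes it from its parts, so the build tag and the deployment image cannot drift apart; the helper name artifact_registry_image is hypothetical.

```python
from typing import Optional


def artifact_registry_image(location: str, project: str, repository: str,
                            image: str, tag: Optional[str] = None) -> str:
    """Compose the URI of a Docker image hosted in Artifact Registry."""
    uri = f"{location}-docker.pkg.dev/{project}/{repository}/{image}"
    # Without an explicit tag, Docker tooling falls back to "latest".
    return f"{uri}:{tag}" if tag else uri


DEPLOY_IMAGE = artifact_registry_image(
    "europe-west4", "argolis-rafaelsanchez-ml-dev",
    "ml-pipelines-repo", "finetuning_flan_t5_base")
print(DEPLOY_IMAGE)
```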
The next step is to upload the model to Vertex AI Model Registry and deploy it to Vertex AI Prediction, using the image built above. Note that prediction runs on a CPU machine (n1-standard-4), without a GPU:
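```python
from google.cloud import aiplatform

# Image pushed to Artifact Registry in the previous step
DEPLOY_IMAGE = 'europe-west4-docker.pkg.dev/argolis-rafaelsanchez-ml-dev/ml-pipelines-repo/finetuning_flan_t5_base'
# Routes and port must match the ones served by the FastAPI/Uvicorn app
HEALTH_ROUTE = "/health"
PREDICT_ROUTE = "/predict"
SERVING_CONTAINER_PORTS = [8080]

# Register the custom container as a new model in Vertex AI Model Registry
model = aiplatform.Model.upload(
    display_name='custom-finetuning_flan_t5_base',
    description='Finetuned Flan T5 model with Uvicorn and FastAPI',
    serving_container_image_uri=DEPLOY_IMAGE,
    serving_container_predict_route=PREDICT_ROUTE,
    serving_container_health_route=HEALTH_ROUTE,
    serving_container_ports=SERVING_CONTAINER_PORTS,
)

# Retrieve the Model on Vertex AI by its resource name
model = aiplatform.Model(model.resource_name)

# Deploy to a new endpoint on a CPU-only machine; sync=False returns
# immediately, and wait() blocks until the deployment has finished
endpoint = model.deploy(
    machine_type='n1-standard-4',
    sync=False
)
endpoint.wait()
```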
Online predictions
Predict using the Vertex AI Python SDK:
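```python
from google.cloud import aiplatform

# Retrieve an Endpoint on Vertex AI (ENDPOINT_ID is a placeholder), or reuse the one returned by model.deploy
endpoint = aiplatform.Endpoint(ENDPOINT_ID)
# sample: a SAMSum record with a "dialogue" field; here, the dialogue also used in the REST example
sample = {"dialogue": "Greg: Hi Mum, how's the cat doing?\nPatti: I just rang the vets, she's fine!\nGreg: Thank God, been worrying about her all day!\nPatti: They said I can pick her up later. I'll pop home and fetch the cage after work. Should be there at 5ish.\nGreg: Good, see you at home, bye!"}
print(endpoint.predict([[sample["dialogue"]]]))
# Output:
# Prediction(predictions=[[["Patti's cat is fine. Patti will pick her up later. Patti will fetch the cage after work."]]],
# deployed_model_id='4495568794441220096', model_version_id='1', model_resource_name='projects/989788194604/locations/europe-west4/models/2128205910630203392', explanations=None)
```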
Predict using the Vertex AI REST API:
```bash
curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" \
https://europe-west4-aiplatform.googleapis.com/v1/projects/989788194604/locations/europe-west4/endpoints/7028544517674369024:predict \
-d "{\"instances\": [\"
Greg: Hi Mum, how's the cat doing?
Patti: I just rang the vets, she's fine!
Greg: Thank God, been worrying about her all day!
Patti: They said I can pick her up later. I'll pop home and fetch the cage after work. Should be there at 5ish.
Greg: Good, see you at home, bye!
\"]}"
```
Output:
```json
{
  "predictions": [
    [
      [
        "Patti will pick up the cat later. Patti will pop home and fetch the cage after work."
      ]
    ]
  ],
  "deployedModelId": "4495568794441220096",
  "model": "projects/989788194604/locations/europe-west4/models/2128205910630203392",
  "modelDisplayName": "custom-finetuning-flan-t5-base",
  "modelVersionId": "1"
}
```
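The same REST call can be made from Python without the SDK. A minimal standard-library sketch, assuming you already have an OAuth access token (for example from gcloud auth print-access-token) and one dialogue per instance, as in the curl example; json.dumps escapes the newlines inside each dialogue, so no manual shell quoting is needed. The helper names are hypothetical, and first_summaries only unwraps the nested-list shape shown in the output above.

```python
import json
import urllib.request


def build_predict_request(project: str, region: str, endpoint_id: str,
                          access_token: str, dialogues: list) -> urllib.request.Request:
    """Build a POST request for the endpoint's :predict method."""
    url = (f"https://{region}-aiplatform.googleapis.com/v1/projects/{project}"
           f"/locations/{region}/endpoints/{endpoint_id}:predict")
    # json.dumps escapes the newlines inside each dialogue.
    body = json.dumps({"instances": list(dialogues)}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Authorization": f"Bearer {access_token}",
                 "Content-Type": "application/json"},
        method="POST",
    )


def first_summaries(response: dict) -> list:
    """Return one summary string per instance from the nested predictions."""
    summaries = []
    for prediction in response.get("predictions", []):
        # This endpoint wraps each summary in nested lists; unwrap them.
        while isinstance(prediction, list) and prediction:
            prediction = prediction[0]
        summaries.append(prediction)
    return summaries
```

To send the request, open it with urllib.request.urlopen(req) and pass json.load(resp) to first_summaries.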
Summary and acknowledgements
This post summarizes how to fine-tune and deploy a Flan-T5 Base model to an online endpoint in Vertex AI.
This example does not use model parallelism or distributed training, since the model can be considered “small” (approx. 250M parameters). However, bigger versions of Flan-T5, like google/flan-t5-xxl, will require distributed training.
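A back-of-the-envelope estimate shows why roughly 250M parameters counts as small here. It assumes full fine-tuning in fp32 with Adam, which keeps about 16 bytes per parameter (weights, gradients and the two moment estimates), ignores activations, and takes about 11B parameters for flan-t5-xxl.

```python
# Assumption: full fp32 fine-tuning with Adam keeps, per parameter,
# 4 bytes of weights, 4 of gradients and 8 for the two moment estimates.
BYTES_PER_PARAM = 4 + 4 + 8
A100_MEMORY_GB = 40


def training_state_gb(num_params: float, bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Memory for weights, gradients and optimizer state, ignoring activations."""
    return num_params * bytes_per_param / 1e9


for name, params in [("flan-t5-base", 250e6), ("flan-t5-xxl", 11e9)]:
    gb = training_state_gb(params)
    verdict = "fits on" if gb < A100_MEMORY_GB else "exceeds"
    print(f"{name}: ~{gb:.0f} GB of training state, {verdict} one {A100_MEMORY_GB} GB A100")
```

Activations add to these figures, while a memory-lighter optimizer such as Adafactor reduces them; the estimate only shows the order of magnitude: about 4 GB for the base model versus about 176 GB for the XXL model.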
You can find the full code in this repo.
I would like to thank Camus Ma for comments and contributions to this post.