Fine-tuning Flan-T5 Base and online deployment in Vertex AI

Rafa Sanchez
Google Cloud - Community
Feb 28, 2023
Fig.1 How Flan works. Source: Google blog

Google has released the checkpoints of several open-source language models, including BERT, T5 and UL2. Among the most interesting in terms of performance and results is Flan-T5, a variant of T5 that generalizes better and outperforms T5 on many NLP tasks.

This post shows how to fine-tune a Flan-T5-Base model on the SAMSum dataset (summaries of English conversations) using Vertex AI. In particular, you will use Vertex AI Training with a 1xA100 GPU (40 GB HBM) for fine-tuning, and Vertex AI Prediction for online predictions.

The dataset: SAMSum

The SAMSum dataset contains about 16k messenger-like conversations with summaries. The conversations were created and written down by linguists fluent in English. This use case is similar, for example, to the conversation summaries feature available in Google Chat.
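Each SAMSum record pairs a `dialogue` with a reference `summary` (plus an `id`). Below is a minimal sketch of how one record could become a model input and a training target. It assumes a plain "summarize: " instruction prefix, because the prompt format used in flant5base_trainer.py is not shown in this post. The record itself is hypothetical.

```python
from typing import Dict, Tuple

# Assumed instruction prefix; the actual trainer script may use a different prompt.
PREFIX = "summarize: "


def to_example(record: Dict[str, str]) -> Tuple[str, str]:
    """Turn one SAMSum-style record into an (input_text, target_text) pair."""
    # Strip surrounding whitespace so stray newlines don't leak into the prompt
    dialogue = record["dialogue"].strip()
    summary = record["summary"].strip()
    return PREFIX + dialogue, summary


# Hypothetical record with the same fields as SAMSum (id, dialogue, summary)
record = {
    "id": "example-1",
    "dialogue": "Anna: Are we still on for lunch?\nBen: Yes, see you at 1.",
    "summary": "Anna and Ben will meet for lunch at 1.",
}

source, target = to_example(record)
print(source)
print(target)
```

In a Hugging Face pipeline, both strings would then be tokenized with the Flan-T5 tokenizer before training.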

The model: Flan-T5 Base

Flan-T5 is a variant of T5 that outperforms it on a wide variety of tasks. It is multilingual and is trained with instruction fine-tuning, a technique that generally improves the performance and usability of pretrained language models such as T5.

Fig.2 T5 model. Source: Google blog

Flan-T5 has public checkpoints in several sizes. This code sample uses the google/flan-t5-base version.

Fine-tuning

Using libraries from Hugging Face, you will fine-tune a Flan-T5-Base model on the SAMSum dataset of English conversations. You will start from the google/flan-t5-base checkpoint, fine-tune it, and upload the result to Vertex AI Model Registry.

The model is fine-tuned on Vertex AI with a 1xA100 NVIDIA GPU (40 GB HBM). The code launches a training pipeline, a type of Vertex AI job, which runs three steps in sequence. First, it creates a Vertex AI Managed Dataset (skipped here, because the dataset is downloaded from Hugging Face). Second, it runs the workload in Vertex AI Training (fine-tuning, in this case). Third, it uploads the model to Vertex AI Model Registry.

The fine-tuning should take around 45 minutes to complete.

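from google.cloud import aiplatform

# PROJECT_ID, REGION and BUCKET_URI are placeholders for your project, region and staging bucket;
# SERVICE_ACCOUNT and TENSORBOARD_RESOURCE_NAME must also be defined for your environment
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

# The training script runs in a prebuilt PyTorch GPU container with extra pip requirements
job = aiplatform.CustomTrainingJob(
    display_name="flan_t5_base_finetuning_gpu_tensorboard",
    script_path="flant5base_trainer.py",
    requirements=["py7zr==0.20.4",
                  "nltk==3.7",
                  "evaluate==0.4.0",
                  "rouge_score==0.1.2",
                  "transformers==4.25.1",
                  "tensorboard==2.11.2",
                  "datasets==2.9.0",
                  "google-cloud-storage==2.7.0"],
    container_uri="europe-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-10:latest",
    model_serving_container_image_uri="europe-docker.pkg.dev/vertex-ai/prediction/pytorch-gpu.1-10:latest",
)

# Run on one a2-highgpu-1g machine (1x NVIDIA A100), logging to Vertex AI TensorBoard;
# run() returns the model uploaded to Vertex AI Model Registry
model = job.run(
    model_display_name='flan-t5-base-finetuning-gpu-tensorboard',
    replica_count=1,
    service_account=SERVICE_ACCOUNT,
    tensorboard=TENSORBOARD_RESOURCE_NAME,
    machine_type="a2-highgpu-1g",
    accelerator_type="NVIDIA_TESLA_A100",
    accelerator_count=1,
)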

Uvicorn

Serving the model requires a custom container image for predictions. Vertex AI requires a custom container to run an HTTP server. Specifically, the container must listen and respond to liveness checks, health checks, and prediction requests.

This repo uses FastAPI and Uvicorn to implement the HTTP server, which listens for requests on 0.0.0.0. Uvicorn is an ASGI web server implementation for Python that currently supports HTTP/1.1 and WebSockets. There is also a Docker image that runs Uvicorn managed by Gunicorn, built for high-performance FastAPI web applications in Python 3.6+ with performance auto-tuning. A Uvicorn server is launched with:

uvicorn main:app --host 0.0.0.0 --port 8080
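The contract the container has to honour can be illustrated without FastAPI. The following is a minimal sketch that uses only Python's standard library `http.server` as a stand-in for the repo's FastAPI app, so it is not the repo's code. It answers GET on the health route with 200. It answers POST on the predict route with `{"predictions": [...]}`, with one entry per item in `"instances"`. The `summarize` function is a hypothetical placeholder for Flan-T5 inference. Port and routes default to the values used later in this post (8080, /health, /predict) and can be overridden by the `AIP_HTTP_PORT`, `AIP_HEALTH_ROUTE` and `AIP_PREDICT_ROUTE` environment variables that Vertex AI sets in serving containers.

```python
import json
import os
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import List

# Vertex AI passes the port and routes to serving containers through these
# environment variables; the defaults mirror the values used in this post.
PORT = int(os.environ.get("AIP_HTTP_PORT", "8080"))
HEALTH_ROUTE = os.environ.get("AIP_HEALTH_ROUTE", "/health")
PREDICT_ROUTE = os.environ.get("AIP_PREDICT_ROUTE", "/predict")


def summarize(dialogue: str) -> str:
    """Hypothetical placeholder for Flan-T5 inference: returns the first line."""
    lines = dialogue.strip().splitlines()
    return lines[0] if lines else ""


class Handler(BaseHTTPRequestHandler):
    def _send_json(self, status: int, payload: dict) -> None:
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self) -> None:
        # Health check: a 200 response tells Vertex AI the container is ready
        if self.path == HEALTH_ROUTE:
            self._send_json(200, {"status": "healthy"})
        else:
            self._send_json(404, {"error": "not found"})

    def do_POST(self) -> None:
        if self.path != PREDICT_ROUTE:
            self._send_json(404, {"error": "not found"})
            return
        length = int(self.headers.get("Content-Length", "0"))
        request = json.loads(self.rfile.read(length) or b"{}")
        instances: List[str] = request.get("instances", [])
        # One prediction per instance, in the same order
        self._send_json(200, {"predictions": [summarize(x) for x in instances]})

    def log_message(self, fmt: str, *args) -> None:
        pass  # silence per-request logging


def serve(port: int = PORT) -> HTTPServer:
    """Listen on 0.0.0.0 in a background thread; port=0 picks a free port."""
    server = HTTPServer(("0.0.0.0", port), Handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def call_predict(port: int, instances: List[str]) -> dict:
    """Send a Vertex-style {"instances": [...]} request to the local server."""
    data = json.dumps({"instances": instances}).encode("utf-8")
    req = urllib.request.Request(
        f"http://127.0.0.1:{port}{PREDICT_ROUTE}",
        data=data,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```

In the real container, `summarize` would call the fine-tuned model, and the server would be started by Uvicorn as shown above rather than by `serve()`.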

Export model from Vertex AI Model Registry

After fine-tuning, the model is stored in Vertex AI Model Registry. Since you are going to embed the model in a custom Vertex AI prediction container, you must export it and copy it to the local directory predict/model-output-flan-t5-base (for example, with gcloud storage cp). The predict/ directory contains the Dockerfile used to build the custom prediction container image.

Before you build the custom prediction container image, the model must be available in the predict/model-output-flan-t5-base directory, with contents similar to the following. Note that both the model and the tokenizer files must be included:

config.json
pytorch_model.bin
special_tokens_map.json
spiece.model
tokenizer.json
tokenizer_config.json
training_args.bin
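Checking the directory before building the image can save a failed build. The following is a small sketch that verifies only a few core files common to any Hugging Face model and tokenizer export. The `REQUIRED_FILES` tuple is an assumption, so extend it to match whatever your trainer writes.

```python
from pathlib import Path
from typing import Iterable, List

MODEL_DIR = Path("predict/model-output-flan-t5-base")

# Core files a Hugging Face model + tokenizer export should contain;
# the full list depends on the tokenizer and the transformers version.
REQUIRED_FILES = (
    "config.json",
    "pytorch_model.bin",
    "special_tokens_map.json",
    "tokenizer_config.json",
)


def missing_files(model_dir: Path, required: Iterable[str] = REQUIRED_FILES) -> List[str]:
    """Return the required file names that are not present in model_dir."""
    return [name for name in required if not (model_dir / name).is_file()]
```

Calling `missing_files(MODEL_DIR)` before `gcloud builds submit` should return an empty list once the export is in place.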

Build Custom Prediction Container image and upload model to Vertex AI Model Registry

From the predict/ directory, build the Docker image with Cloud Build and push it to Artifact Registry:

gcloud auth configure-docker europe-west4-docker.pkg.dev
gcloud builds submit --tag europe-west4-docker.pkg.dev/argolis-rafaelsanchez-ml-dev/ml-pipelines-repo/finetuning_flan_t5_base

The next step is to upload the model to Vertex AI Model Registry and deploy it to Vertex AI Prediction, using the image you just built. Note that no GPU is used for prediction:

from google.cloud import aiplatform

DEPLOY_IMAGE = 'europe-west4-docker.pkg.dev/argolis-rafaelsanchez-ml-dev/ml-pipelines-repo/finetuning_flan_t5_base'
HEALTH_ROUTE = "/health"
PREDICT_ROUTE = "/predict"
SERVING_CONTAINER_PORTS = [8080]

# Register the custom serving container as a model in Vertex AI Model Registry
model = aiplatform.Model.upload(
    display_name='custom-finetuning_flan_t5_base',
    description='Finetuned Flan T5 model with Uvicorn and FastAPI',
    serving_container_image_uri=DEPLOY_IMAGE,
    serving_container_predict_route=PREDICT_ROUTE,
    serving_container_health_route=HEALTH_ROUTE,
    serving_container_ports=SERVING_CONTAINER_PORTS,
)

# Retrieve a Model on Vertex
model = aiplatform.Model(model.resource_name)

# Deploy the model to a new endpoint on a CPU-only machine (no GPU)
endpoint = model.deploy(
    machine_type='n1-standard-4',
    sync=False
)
# With sync=False, deploy() returns immediately; wait() blocks until deployment finishes
endpoint.wait()

Online predictions

Predict using the Vertex AI Python SDK:

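# Retrieve an Endpoint on Vertex (ENDPOINT_ID is a placeholder; skip this line if `endpoint` from the deployment step is still in scope)
endpoint = aiplatform.Endpoint(ENDPOINT_ID)
# `sample` is one record of the SAMSum dataset; its "dialogue" field holds the conversation
print(endpoint.predict([[sample["dialogue"]]]))
# Output:
# Prediction(predictions=[[["Patti's cat is fine. Patti will pick her up later. Patti will fetch the cage after work."]]],
# deployed_model_id='4495568794441220096', model_version_id='1', model_resource_name='projects/989788194604/locations/europe-west4/models/2128205910630203392', explanations=None)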
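The `model_resource_name` in the prediction response follows the pattern projects/{project}/locations/{location}/models/{model_id}. The following is a small helper that splits such a name into its parts, for example to recover the project number and region when building REST URLs. It is hypothetical and not part of the Vertex AI SDK.

```python
from typing import Dict


def parse_resource_name(name: str) -> Dict[str, str]:
    """Split 'projects/P/locations/L/models/M' into {'projects': P, 'locations': L, 'models': M}."""
    parts = name.strip("/").split("/")
    if len(parts) % 2 != 0:
        raise ValueError(f"Unexpected resource name: {name!r}")
    # Resource names alternate between collection names and IDs
    return dict(zip(parts[0::2], parts[1::2]))


info = parse_resource_name(
    "projects/989788194604/locations/europe-west4/models/2128205910630203392"
)
print(info["locations"], info["models"])
```

The same helper works for endpoint names such as projects/{project}/locations/{location}/endpoints/{endpoint_id}.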

Predict using Vertex AI REST API:

curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" \
https://europe-west4-aiplatform.googleapis.com/v1/projects/989788194604/locations/europe-west4/endpoints/7028544517674369024:predict \
-d "{\"instances\": [\"
Greg: Hi Mum, how's the cat doing?
Patti: I just rang the vets, she's fine!
Greg: Thank God, been worrying about her all day!
Patti: They said I can pick her up later. I'll pop home and fetch the cage after work. Should be there at 5ish.
Greg: Good, see you at home, bye!
\"]}"
# Output
{
"predictions": [
[
[
"Patti will pick up the cat later. Patti will pop home and fetch the cage after work."
]
]
],
"deployedModelId": "4495568794441220096",
"model": "projects/989788194604/locations/europe-west4/models/2128205910630203392",
"modelDisplayName": "custom-finetuning-flan-t5-base",
"modelVersionId": "1"
}
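The same REST call can be made from Python with only the standard library. Here `json.dumps` also escapes the newlines in the dialogue so that the request body is valid JSON. The following is a minimal sketch. It assumes an access token from `gcloud auth print-access-token` and reuses the project number, region and endpoint ID from the curl example. `predict_url`, `build_predict_request`, `first_summary` and `predict` are hypothetical helpers.

```python
import json
import subprocess
import urllib.request
from typing import Any, Dict


def predict_url(project: str, region: str, endpoint_id: str) -> str:
    """Vertex AI v1 REST URL for online prediction on an endpoint."""
    return (
        f"https://{region}-aiplatform.googleapis.com/v1/projects/{project}"
        f"/locations/{region}/endpoints/{endpoint_id}:predict"
    )


def build_predict_request(url: str, token: str, dialogue: str) -> urllib.request.Request:
    """POST request with body {"instances": [dialogue]}; json.dumps escapes newlines."""
    body = json.dumps({"instances": [dialogue]}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        method="POST",
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
    )


def first_summary(response: Dict[str, Any]) -> str:
    """Extract the first summary from the nested 'predictions' list in the response."""
    return response["predictions"][0][0][0]


def predict(project: str, region: str, endpoint_id: str, dialogue: str) -> str:
    # Requires an authenticated gcloud CLI on the PATH
    token = subprocess.run(
        ["gcloud", "auth", "print-access-token"],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    request = build_predict_request(predict_url(project, region, endpoint_id), token, dialogue)
    with urllib.request.urlopen(request) as resp:
        return first_summary(json.load(resp))
```

With the IDs from the curl example, this would be called as `predict("989788194604", "europe-west4", "7028544517674369024", dialogue)`.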

Summary and acknowledgements

This post showed how to fine-tune a Flan-T5 Base model and deploy it to an online endpoint in Vertex AI.

Note that this example does not use model parallelism or distributed training, since the model can be considered “small” (approx. 250M parameters). Bigger versions of Flan-T5, like google/flan-t5-xxl, will require distributed training.
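A rough back-of-the-envelope estimate makes the "small" argument concrete. The sketch assumes full fine-tuning in fp32 with Adam: 4 bytes each for weights and gradients, plus 8 bytes for the two optimizer moments, or about 16 bytes per parameter. It ignores activations, so real memory use is higher. The parameter counts are the approximate sizes of the checkpoints: 250M for Base and 11B for XXL.

```python
# Bytes per parameter for fp32 full fine-tuning with Adam:
# 4 (weights) + 4 (gradients) + 8 (Adam first and second moments)
BYTES_PER_PARAM = 4 + 4 + 8
A100_HBM_GB = 40  # the single A100 used in this post


def training_state_gb(num_params: float) -> float:
    """Approximate GB for weights, gradients and Adam state (activations excluded)."""
    return num_params * BYTES_PER_PARAM / 1e9


for name, params in [("flan-t5-base", 250e6), ("flan-t5-xxl", 11e9)]:
    gb = training_state_gb(params)
    fits = "fits" if gb < A100_HBM_GB else "does not fit"
    print(f"{name}: ~{gb:.0f} GB -> {fits} on one {A100_HBM_GB} GB A100")
```

Under these assumptions, Base needs about 4 GB of training state while XXL needs about 176 GB. That is why the larger model has to be sharded across several accelerators.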

You can find the full code in this repo.

I would like to thank Camus Ma for comments and contributions to this post.



I'm Rafa, Machine Learning specialist working @GoogleCloud. Ph.D. and Lecturer at the @uc3m University about IoT and on-device ML.