Fine-tune Google PaLM 2 with Scikit-LLM

Iryna Kondrashchenko
Jul 4, 2023 · 5 min read

The latest update of Scikit-LLM brings the possibility of fine-tuning PaLM 2 — the most advanced LLM from Google.

Introduction

Recently, Google opened access to PaLM 2, their most advanced large language model. Like other LLMs, a pre-trained PaLM 2 can be used for a variety of language tasks: classification, summarization, question answering, etc. In addition, it can be fine-tuned on custom data directly in Google Cloud for enhanced predictive power.

In this article, we provide a detailed guide on fine-tuning Google PaLM 2 with Scikit-LLM using the newly added estimators. This is arguably the easiest approach, as it exposes a familiar scikit-learn compatible API.

If you are not familiar with Scikit-LLM, check out my previous article about it.

Google Cloud Set Up

Currently, the only way to use PaLM 2 is via the Vertex AI platform. Therefore, you need to create a Google account, log in to the Google Cloud Console, and create a Google Cloud project. Once the project is created, select it from the list of projects next to the Google Cloud logo (upper left corner). Then search for and select Vertex AI in the search bar.

Once you are on the Vertex AI main page, enable all recommended APIs. Note that enabling them can take several minutes.

Finally, we need to install the Google Cloud CLI on the local machine by following the steps in the official documentation, and set the application default credentials by running the following command:

gcloud auth application-default login

If everything worked properly, a confirmation page will open in your browser.
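Alternatively, once the CLI is installed, the console steps above can be scripted from the terminal. A minimal sketch, with a placeholder project ID (note that the console's "enable all recommended APIs" button enables more APIs than the single one shown here):

# Create a new project and make it the default (the ID is a placeholder)
gcloud projects create my-palm-tuning-project
gcloud config set project my-palm-tuning-project

# Enable the core Vertex AI API
gcloud services enable aiplatform.googleapis.com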

Using PaLM 2 with Scikit-LLM

In this section we will see how to use the available Scikit-LLM estimators to interact with PaLM 2 model in Vertex AI.

1. Install Scikit-LLM.

pip install scikit-llm

2. Set the project ID, which can be found in the list of your projects.

from skllm.config import SKLLMConfig

SKLLMConfig.set_google_project("YOUR_PROJECT_ID")

3. To verify that Scikit-LLM is properly connected to Google Cloud, let's run a simple zero-shot classification task. In this setting we are not fine-tuning the model yet, so it should take only a few seconds to run.

from skllm.models.vertex.classification import ZeroShotVertexClassifier
from skllm.datasets import get_classification_dataset
from sklearn.model_selection import train_test_split

# Load the demo classification dataset shipped with Scikit-LLM
X, y = get_classification_dataset()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

clf = ZeroShotVertexClassifier()
clf.fit(X_train, y_train)
labels = clf.predict(X_test)

print(labels)

If everything worked properly, you will see a list of predicted labels (positive/neutral/negative).

PaLM 2 Fine-Tuning with Scikit-LLM

In order to fine-tune the model, we need 64 cores of the TPU v3 pod training resource. By default this quota is set to 0 cores and has to be increased. To do that, go to Quotas, filter the quotas for "Restricted image training TPU V3 pod cores per region", and select the "europe-west4" region (at the moment this is the only region available for PaLM tuning).

Then click on "Edit Quotas", set the limit to 64, and submit the request. Since all requests are processed manually by Google Support, it may take several days for the quota to be increased.

In principle, PaLM 2 tuning on TPU is subject to regular pricing for custom model training. However, at the time of writing this article, PaLM 2 tuning is still in the preview stage, so all charges are discounted by 100%. Please check the current status and pricing before training the model to avoid unexpected costs.

Once the quota has been increased, you can proceed with fine-tuning right away. Scikit-LLM provides two models: VertexClassifier and TunableVertexText2Text. The difference between them is that VertexClassifier is optimized for text classification tasks, whereas TunableVertexText2Text can be trained on arbitrary input-output pairs.

Example 1: Classification

To use VertexClassifier, we can reuse our zero-shot example; the only change needed is replacing ZeroShotVertexClassifier with VertexClassifier. The number of update steps can be selected following the official recommendations from Google.

from skllm.models.vertex.classification.tunable import VertexClassifier
from skllm.datasets import get_classification_dataset
from sklearn.model_selection import train_test_split

X, y = get_classification_dataset()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

clf = VertexClassifier(n_update_steps=100)
clf.fit(X_train, y_train)
labels = clf.predict(X_test)

print(labels)
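Since the estimators expose the scikit-learn interface, the predictions can be scored with standard scikit-learn metrics. A minimal sketch, reusing y_test and labels from the snippet above:

from sklearn.metrics import accuracy_score

# Compare the predicted labels against the held-out ground truth
print(accuracy_score(y_test, labels))

The same two lines work for the zero-shot example, which makes it easy to check whether fine-tuning actually improves the accuracy.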

Example 2: General Tuning

TunableVertexText2Text follows the same interface but accepts arbitrary input-output text pairs. As a toy example, we can teach the model a fixed answer to a single question:

from skllm.models.vertex.text2text.tunable import TunableVertexText2Text

X = ["Tell us something about Scikit-LLM" for _ in range(20)]
y = ["Scikit-LLM is awesome" for _ in range(20)]

model = TunableVertexText2Text(n_update_steps=100)
model.fit(X, y)
predictions = model.predict(["Tell us something about Scikit-LLM"])

print(predictions[0])

# > Scikit-LLM is awesome

Behind the scenes, Scikit-LLM will automatically create a tuning pipeline in Vertex AI. As soon as the training starts, a link to the pipeline will be printed in the terminal, which you can use to monitor the state of the job.

Exemplary fine-tuning pipeline

Conclusion

In this article, we explored one of the easiest ways to fine-tune Google PaLM 2. By leveraging Scikit-LLM, this can be done in just a few lines of code using the familiar scikit-learn interface.

If you want to learn more about LLM-related Python tools, check out my previous article about Dingo — a microframework for creating simple AI agents.
