Using TPUs for fine-tuning and deploying LLMs via dstack

Andrey Cheptsov
Google Cloud - Community
Sep 16, 2024

dstack is an open-source container orchestrator designed to simplify the development and deployment of AI workloads on any cloud or on-prem. It supports NVIDIA and AMD GPUs as well as Google Cloud TPUs out of the box. If you’re using or planning to use TPUs with Google Cloud, you can now do so via dstack.

Before you can use dstack with your GCP account, you have to configure the gcp backend via ~/.dstack/server/config.yml:

projects:
- name: main
  backends:
  - type: gcp
    project_id: gcp-project-id
    creds:
      type: default
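
If the machine running the dstack server has no default application credentials configured, the gcp backend also accepts a service account key. Below is a sketch based on my reading of the dstack docs (double-check the exact field names there), assuming a key file named service_account.json:

projects:
- name: main
  backends:
  - type: gcp
    project_id: gcp-project-id
    creds:
      type: service_account
      filename: service_account.json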

Once the backend is configured, install dstack and start the server:

pip install "dstack[all]" -U
dstack server

Now, you can run dev environments, tasks, and services via dstack. Read below to find out how to use TPUs with dstack for fine-tuning and deploying LLMs, leveraging open-source tools like Hugging Face’s Optimum TPU and vLLM.
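
This post focuses on tasks and services, but for completeness, here is roughly what a TPU-backed dev environment could look like (a sketch using dstack’s dev environment configuration; the name is mine):

type: dev-environment
name: tpu-dev
python: "3.11"
ide: vscode

resources:
  gpu: v5litepod-4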

Deployment

You can use any TPU-capable serving framework, such as vLLM or TGI. Below are examples of services that deploy Llama 3.1 8B using Optimum TPU and vLLM.

Optimum TPU

type: service
name: llama31-service-optimum-tpu

image: dstackai/optimum-tpu:llama31
env:
  - HUGGING_FACE_HUB_TOKEN
  - MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct
  - MAX_TOTAL_TOKENS=4096
  - MAX_BATCH_PREFILL_TOKENS=4095
commands:
  - text-generation-launcher --port 8000
port: 8000

spot_policy: auto
resources:
  gpu: v5litepod-4

model:
  format: tgi
  type: chat
  name: meta-llama/Meta-Llama-3.1-8B-Instruct

The dstackai/optimum-tpu:llama31 image is built from a pull request that adds Llama 3.1 support to Optimum TPU; once it is merged, the official Docker image can be used instead.

vLLM

type: service
name: llama31-service-vllm-tpu

env:
  - MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct
  - HUGGING_FACE_HUB_TOKEN
  - DATE=20240828
  - TORCH_VERSION=2.5.0
  - VLLM_TARGET_DEVICE=tpu
  - MAX_MODEL_LEN=4096
commands:
  - pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp311-cp311-linux_x86_64.whl
  - pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp311-cp311-linux_x86_64.whl
  - pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
  - pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
  - git clone https://github.com/vllm-project/vllm.git
  - cd vllm
  - pip install -r requirements-tpu.txt
  - apt-get install -y libopenblas-base libopenmpi-dev libomp-dev
  - python setup.py develop
  - vllm serve $MODEL_ID
    --tensor-parallel-size 4
    --max-model-len $MAX_MODEL_LEN
    --port 8000
port: 8000

spot_policy: auto
resources:
  gpu: v5litepod-4

model:
  format: openai
  type: chat
  name: meta-llama/Meta-Llama-3.1-8B-Instruct

If you specify model when running a service, dstack automatically registers the model on the gateway’s global endpoint and lets you chat with it via the control plane UI.
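
Once deployed, you can also query the model directly through the gateway’s OpenAI-compatible endpoint. A sketch, assuming your gateway is set up on the hypothetical example.com domain and <dstack token> is your dstack user token:

curl https://gateway.example.com/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer <dstack token>' \
  -d '{
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "What is a TPU?"}],
    "max_tokens": 128
  }'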

Memory requirements

Approximate memory requirements for serving scale with model size: in 16-bit precision, a model needs about 2 GB of memory per billion parameters, plus headroom for the KV cache. Llama 3.1 8B, for example, needs roughly 16 GB for its weights alone.

Note, v5litepod is optimized for serving transformer-based models. Each core is equipped with 16 GB of memory, so a v5litepod-4 provides 64 GB in total.

Supported frameworks

For serving on TPUs, both Hugging Face’s Optimum TPU (which serves models through TGI’s text-generation-launcher) and vLLM are supported, as shown in the examples above.

Running a configuration

Once the configuration is ready, run dstack apply -f <configuration file>, and dstack will automatically provision the cloud resources and run the configuration.
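
For example, assuming the Optimum TPU service above is saved as service.dstack.yml (a filename of my choosing):

dstack apply -f service.dstack.yml

dstack shows a plan with the available TPU offers and their prices and asks for confirmation before provisioning.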

Fine-tuning

Below is an example of fine-tuning Llama 3.1 8B using Optimum TPU and the Abirate/english_quotes dataset.

type: task
name: optimum-tpu-llama-train

python: "3.11"

env:
  - HUGGING_FACE_HUB_TOKEN
commands:
  - git clone -b add_llama_31_support https://github.com/dstackai/optimum-tpu.git
  - mkdir -p optimum-tpu/examples/custom/
  - cp examples/fine-tuning/optimum-tpu/llama31/train.py optimum-tpu/examples/custom/train.py
  - cp examples/fine-tuning/optimum-tpu/llama31/config.yaml optimum-tpu/examples/custom/config.yaml
  - cd optimum-tpu
  - pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html
  - pip install datasets evaluate
  - pip install accelerate -U
  - pip install peft
  - python examples/custom/train.py examples/custom/config.yaml

resources:
  gpu: v5litepod-8
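
Env variables declared without a value, like HUGGING_FACE_HUB_TOKEN, are picked up from the environment where you run the CLI. So, assuming the task above is saved as train.dstack.yml (a filename of my choosing), you’d launch it like this:

export HUGGING_FACE_HUB_TOKEN=<your token>
dstack apply -f train.dstack.yml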

Memory requirements

Fine-tuning requires considerably more memory than serving: a full fine-tune with the Adam optimizer takes on the order of 16 bytes per parameter once gradients and optimizer states are counted, which is why the example above requests a v5litepod-8 for Llama 3.1 8B. Parameter-efficient methods such as LoRA (via peft) lower the requirement substantially.

Note, v5litepod is optimized for fine-tuning transformer-based models. Each core is equipped with 16 GB of memory, so a v5litepod-8 provides 128 GB in total.

Supported frameworks

Fine-tuning on TPUs is shown here with Hugging Face’s Optimum TPU. Currently, a maximum of 8 TPU cores can be specified per run, so the largest supported configurations are v2-8, v3-8, v4-8, v5litepod-8, and v5e-8. Multi-host TPU support, allowing larger numbers of cores, is coming soon.

What’s next

The source code for these examples can be found in examples/deployment/optimum-tpu and examples/fine-tuning/optimum-tpu.

To learn more about dstack, check out its GitHub repo and documentation.
