Using TPUs for fine-tuning and deploying LLMs via dstack
`dstack` is an open-source container orchestrator designed to simplify the development and deployment of AI on any cloud or on-prem. It supports NVIDIA, AMD, and TPU accelerators out of the box. If you’re using or planning to use TPUs with Google Cloud, you can now do so via `dstack`.
Before you can use `dstack` with your GCP account, you have to configure the `gcp` backend via `~/.dstack/server/config.yml`:

```yaml
projects:
- name: main
  backends:
  - type: gcp
    project_id: gcp-project-id
    creds:
      type: default
```
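If you'd rather not rely on default application credentials, the `gcp` backend can also be configured with a service account key. Below is a minimal sketch; the key path is an assumption, so point `filename` at wherever your key actually lives, and check the `dstack` backend reference for the fields supported by your version:

```yaml
projects:
- name: main
  backends:
  - type: gcp
    project_id: gcp-project-id
    creds:
      # Assumes a service account key downloaded from the GCP console;
      # the path below is illustrative, not a required location.
      type: service_account
      filename: ~/.dstack/gcp-service-account-key.json
```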
Once the backend is configured, install `dstack` and start the server:

```shell
pip install "dstack[all]" -U
dstack server
```
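One way to try the examples below is to clone the `dstack` repo (which contains the example configurations referenced at the end of this post) and initialize it for use with `dstack`:

```shell
git clone https://github.com/dstackai/dstack
cd dstack
dstack init
```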
Now, you can run dev environments, tasks, and services via `dstack`. Read below to find out how to use TPUs with `dstack` for fine-tuning and deploying LLMs, leveraging open-source tools like Hugging Face’s Optimum TPU and vLLM.
Deployment
You can use any serving framework, such as vLLM or TGI. Below are examples of services that deploy Llama 3.1 8B using Optimum TPU and vLLM.
Optimum TPU
```yaml
type: service
name: llama31-service-optimum-tpu

image: dstackai/optimum-tpu:llama31
env:
  - HUGGING_FACE_HUB_TOKEN
  - MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct
  - MAX_TOTAL_TOKENS=4096
  - MAX_BATCH_PREFILL_TOKENS=4095
commands:
  - text-generation-launcher --port 8000
port: 8000

spot_policy: auto

resources:
  gpu: v5litepod-4

model:
  format: tgi
  type: chat
  name: meta-llama/Meta-Llama-3.1-8B-Instruct
```
Once the pull request is merged, the official Docker image can be used instead of `dstackai/optimum-tpu:llama31`.
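Once the service is up, you can send requests to it. The snippet below is a sketch: it assumes a gateway configured with the domain `example.com` (so the service is reachable at `https://llama31-service-optimum-tpu.example.com`) and uses TGI's standard `/generate` endpoint; replace the hostname and the `dstack` token with your own.

```shell
curl https://llama31-service-optimum-tpu.example.com/generate \
  -X POST \
  -H "Authorization: Bearer <dstack token>" \
  -H "Content-Type: application/json" \
  -d '{
        "inputs": "What is Deep Learning?",
        "parameters": {"max_new_tokens": 64}
      }'
```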
vLLM
```yaml
type: service
name: llama31-service-vllm-tpu

env:
  - MODEL_ID=meta-llama/Meta-Llama-3.1-8B-Instruct
  - HUGGING_FACE_HUB_TOKEN
  - DATE=20240828
  - TORCH_VERSION=2.5.0
  - VLLM_TARGET_DEVICE=tpu
  - MAX_MODEL_LEN=4096
commands:
  - pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp311-cp311-linux_x86_64.whl
  - pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp311-cp311-linux_x86_64.whl
  - pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
  - pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
  - git clone https://github.com/vllm-project/vllm.git
  - cd vllm
  - pip install -r requirements-tpu.txt
  - apt-get install -y libopenblas-base libopenmpi-dev libomp-dev
  - python setup.py develop
  - vllm serve $MODEL_ID
      --tensor-parallel-size 4
      --max-model-len $MAX_MODEL_LEN
      --port 8000
port: 8000

spot_policy: auto

resources:
  gpu: v5litepod-4

model:
  format: openai
  type: chat
  name: meta-llama/Meta-Llama-3.1-8B-Instruct
```
If you specify `model` when running a service, `dstack` will automatically register the model on the gateway's global endpoint and allow you to use it for chat via the control plane UI.
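Since vLLM exposes an OpenAI-compatible API, you can also query the service directly. As above, this sketch assumes a gateway with the domain `example.com`; swap in your own hostname and `dstack` token.

```shell
curl https://llama31-service-vllm-tpu.example.com/v1/chat/completions \
  -X POST \
  -H "Authorization: Bearer <dstack token>" \
  -H "Content-Type: application/json" \
  -d '{
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "What is Deep Learning?"}],
        "max_tokens": 128
      }'
```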
Memory requirements
Below are the approximate memory requirements for serving LLMs and their corresponding TPUs. As a rough rule, serving in bfloat16 takes about 2 GB of memory per billion parameters plus room for the KV cache; Llama 3.1 8B therefore needs roughly 16 GB for weights and fits comfortably on a `v5litepod-4` (4 × 16 GB = 64 GB).
Note that `v5litepod` is optimized for serving transformer-based models. Each core is equipped with 16 GB of memory.
Supported frameworks
Running a configuration
Once the configuration is ready, run `dstack apply -f <configuration file>`, and `dstack` will automatically provision the cloud resources and run the configuration.
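For example, to deploy the Optimum TPU service above (the file name here is illustrative; use whatever path you saved the configuration to):

```shell
# dstack picks up HUGGING_FACE_HUB_TOKEN from the environment,
# since it is listed under `env` without a value.
export HUGGING_FACE_HUB_TOKEN=<your token>
dstack apply -f examples/deployment/optimum-tpu/service.dstack.yml
```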
Fine-tuning
Below is an example of fine-tuning Llama 3.1 8B using Optimum TPU and the Abirate/english_quotes dataset.
```yaml
type: task
name: optimum-tpu-llama-train

python: "3.11"

env:
  - HUGGING_FACE_HUB_TOKEN
commands:
  - git clone -b add_llama_31_support https://github.com/dstackai/optimum-tpu.git
  - mkdir -p optimum-tpu/examples/custom/
  - cp examples/fine-tuning/optimum-tpu/llama31/train.py optimum-tpu/examples/custom/train.py
  - cp examples/fine-tuning/optimum-tpu/llama31/config.yaml optimum-tpu/examples/custom/config.yaml
  - cd optimum-tpu
  - pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html
  - pip install datasets evaluate
  - pip install accelerate -U
  - pip install peft
  - python examples/custom/train.py examples/custom/config.yaml
resources:
  gpu: v5litepod-8
```
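Launching the task works the same way; `dstack` streams the run's output to your terminal, and you can re-attach to it later with the standard CLI commands. The file name below is illustrative:

```shell
export HUGGING_FACE_HUB_TOKEN=<your token>
dstack apply -f examples/fine-tuning/optimum-tpu/llama31/train.dstack.yml

# In another terminal: list runs and view the training logs
dstack ps
dstack logs optimum-tpu-llama-train
```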
Memory requirements
Below are the approximate memory requirements for fine-tuning LLMs and their corresponding TPUs. Fine-tuning needs more memory than serving because gradients and optimizer state must be kept alongside the weights; that is why the task above installs `peft` and runs Llama 3.1 8B on a `v5litepod-8` (8 × 16 GB = 128 GB).
Note that `v5litepod` is optimized for fine-tuning transformer-based models. Each core is equipped with 16 GB of memory.
Supported frameworks
Currently, a maximum of 8 TPU cores can be specified, so the largest supported configurations are `v2-8`, `v3-8`, `v4-8`, `v5litepod-8`, and `v5e-8`. Multi-host TPU support, allowing for a larger number of cores, is coming soon.
What’s next
The source code for these examples can be found in `examples/deployment/optimum-tpu` and `examples/fine-tuning/optimum-tpu`.
To learn more about `dstack`, check out its GitHub repo and documentation.