Run JAX code on TPU slices
Before running the commands in this document, make sure you have followed the instructions in Set up an account and Cloud TPU project.
After you have your JAX code running on a single TPU board, you can scale up your code by running it on a TPU slice. TPU slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running JAX code on TPU slices; for more in-depth information, see Using JAX in multi-host and multi-process environments.
Required roles
To get the permissions that you need to create a TPU and connect to it using SSH, ask your administrator to grant you the following IAM roles on your project:
- TPU Admin (roles/tpu.admin)
- Service Account User (roles/iam.serviceAccountUser)
- Compute Viewer (roles/compute.viewer)
For more information about granting roles, see Manage access to projects, folders, and organizations.
You might also be able to get the required permissions through custom roles or other predefined roles.
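An administrator can grant these roles with `gcloud projects add-iam-policy-binding`. The following is a minimal sketch, not an official script. It assumes a user principal of the form `user:EMAIL`, and `grant_tpu_roles` is a hypothetical helper name. By default it only prints the commands (`DRY_RUN=1`), so you can review them before running anything.

```shell
#!/bin/sh
# Sketch: grant the three roles listed above to one principal.
# grant_tpu_roles is a hypothetical helper, not part of the gcloud CLI.
# With DRY_RUN=1 (the default) the gcloud commands are printed, not executed.
grant_tpu_roles() {
  project="$1"   # Google Cloud project ID
  member="$2"    # principal, for example user:alice@example.com (assumed format)
  for role in roles/tpu.admin roles/iam.serviceAccountUser roles/compute.viewer; do
    if [ "${DRY_RUN:-1}" = "1" ]; then
      echo gcloud projects add-iam-policy-binding "$project" \
        --member="$member" --role="$role"
    else
      gcloud projects add-iam-policy-binding "$project" \
        --member="$member" --role="$role"
    fi
  done
}
```

For example, `grant_tpu_roles "${PROJECT_ID}" "user:alice@example.com"` prints one binding command per role; prefix the call with `DRY_RUN=0` to actually apply the bindings.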
Create a Cloud TPU slice
Create the following environment variables:
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=europe-west4-b
export ACCELERATOR_TYPE=v5litepod-32
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Environment variable descriptions
- PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
- TPU_NAME: The name of the TPU.
- ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
- ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
- RUNTIME_VERSION: The Cloud TPU software version.
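Because the create command depends on all five variables, it can help to check them before calling gcloud. This is a minimal sketch under stated assumptions: `check_tpu_env`, `accelerator_version`, and `accelerator_size` are hypothetical helpers, and the split assumes the `<version>-<size>` naming used by types such as `v5litepod-32`.

```shell
#!/bin/sh
# Sketch: validate the variables exported above before running gcloud.
# check_tpu_env is a hypothetical helper, not part of the gcloud CLI.
check_tpu_env() {
  missing=0
  for var in PROJECT_ID TPU_NAME ZONE ACCELERATOR_TYPE RUNTIME_VERSION; do
    # Indirect lookup that works in POSIX sh: copy the value of $var into $value.
    eval "value=\${$var:-}"
    if [ -z "$value" ]; then
      echo "error: $var is not set" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Split an accelerator type into its version prefix and size suffix.
# Assumes the "<version>-<size>" naming shown in this guide (e.g. v5litepod-32).
accelerator_version() { printf '%s\n' "${1%-*}"; }
accelerator_size()    { printf '%s\n' "${1##*-}"; }
```

With the values from this guide, `check_tpu_env && accelerator_size "${ACCELERATOR_TYPE}"` prints `32`. For v5e types such as `v5litepod-32`, that suffix is the number of chips; other TPU versions may count size differently, so confirm against TPU versions.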
Create a TPU slice using the gcloud command. For example, to create a v5litepod-32 slice, use the following command:

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --version=