Run JAX code on TPU slices

Before running the commands in this document, make sure you have followed the instructions in Set up an account and Cloud TPU project.

After you have your JAX code running on a single TPU board, you can scale up your code by running it on a TPU slice. TPU slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running JAX code on TPU slices; for more in-depth information, see Using JAX in multi-host and multi-process environments.
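What changes when code moves from one board to a slice? Each host in the slice runs its own copy of the program as a separate JAX process, and each process can address only the devices attached to its own host. Collectives such as jax.lax.psum still combine values across every device in the slice. The Python sketch below prints that topology and sums a 1 from every device. It is an illustration only. The function names describe_topology and global_sum_of_ones are hypothetical. Depending on your JAX version, a multi-host run may also need jax.distributed.initialize() to be called before any other JAX call. That call is omitted here so the sketch also runs on a single machine.

```python
# Sketch: how one JAX process sees a (possibly multi-host) TPU slice.
# Assumption: on a multi-host slice, some JAX versions require calling
# jax.distributed.initialize() first; it is omitted so this runs anywhere.
import jax
import jax.numpy as jnp


def describe_topology():
    """Return (process_index, process_count, local_device_count, device_count)."""
    return (
        jax.process_index(),       # this host's index within the slice (0-based)
        jax.process_count(),       # number of hosts (processes) running the program
        jax.local_device_count(),  # devices attached to this host only
        jax.device_count(),        # devices across all hosts in the slice
    )


def global_sum_of_ones():
    """Sum a 1 from every device in the slice using a cross-device collective."""
    # pmap maps over this host's local devices, so the leading axis
    # must equal the local device count.
    x = jnp.ones(jax.local_device_count())
    # psum reduces over the named axis across the devices of all hosts,
    # so every entry of the result equals jax.device_count().
    return jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)


if __name__ == "__main__":
    idx, nproc, nlocal, nglobal = describe_topology()
    print(f"process {idx} of {nproc}: {nlocal} local / {nglobal} global devices")
    print(global_sum_of_ones())
```

On a slice, the same script has to be started on every host at the same time. One way to do this is `gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --worker=all --command="python3 topology.py"`, where topology.py is a hypothetical file name. Each process should then report the full slice's device count as its global count. On a single machine, the script simply reports one process.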

Required roles

To get the permissions that you need to create a TPU and connect to it using SSH, ask your administrator to grant you the following IAM roles on your project:

For more information about granting roles, see Manage access to projects, folders, and organizations.

You might also be able to get the required permissions through custom roles or other predefined roles.

Create a Cloud TPU slice

  1. Create the following environment variables:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Environment variable descriptions

    • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
    • TPU_NAME: The name of the TPU.
    • ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    • RUNTIME_VERSION: The Cloud TPU software version.

  2. Create a TPU slice using the gcloud compute tpus tpu-vm create command. For example, with the environment variables from step 1, the following command creates a v5litepod-32 slice:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=