Deploy TPU workloads on GKE Autopilot


This page describes how to accelerate machine learning (ML) workloads by using Cloud TPU accelerators (TPUs) in Google Kubernetes Engine (GKE) Autopilot clusters. This guidance can help you to select the correct libraries for your ML application frameworks, set up your TPU workloads to run optimally on GKE, and monitor your workloads after deployment.

This page is for Platform admins and operators, Data and AI specialists, and Application developers who want to prepare and run ML workloads on TPUs. To learn more about the common roles, responsibilities, and example tasks that we reference in Google Cloud content, see Common GKE Enterprise user roles and tasks.

Before reading this page, ensure that you're familiar with the following resources:

How TPUs work in Autopilot

To use TPUs in Autopilot workloads, you specify the following in your workload manifest:

  • The TPU version in the spec.nodeSelector field.
  • The TPU topology in the spec.nodeSelector field. The topology must be supported by the specified TPU version.
  • The number of TPU chips in the spec.containers.resources.requests and the spec.containers.resources.limits fields.

When you deploy the workload, GKE provisions nodes that have the requested TPU configuration and schedules your Pods on the nodes. GKE places each workload on its own node so that each Pod can access the full resources of the node with minimized risk of disruption.
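As a rough illustration of the three settings listed above, the following Python sketch assembles the TPU-related part of a Pod spec as a dictionary and prints it as JSON, a format that kubectl also accepts for manifests. The helper name tpu_pod_spec, the container name, and the image are placeholders chosen for this sketch. The node selector keys and the google.com/tpu resource name are the ones used elsewhere on this page.

```python
import json


def tpu_pod_spec(tpu_type: str, topology: str, chips: int) -> dict:
    """Build the TPU-related part of a Pod spec (hypothetical helper)."""
    tpu_resources = {"google.com/tpu": chips}
    return {
        # TPU version and topology both go in spec.nodeSelector.
        "nodeSelector": {
            "cloud.google.com/gke-tpu-accelerator": tpu_type,
            "cloud.google.com/gke-tpu-topology": topology,
        },
        "containers": [
            {
                "name": "tpu-job",        # placeholder container name
                "image": "python:3.10",   # placeholder image
                # The chip count appears in both requests and limits.
                "resources": {
                    "requests": dict(tpu_resources),
                    "limits": dict(tpu_resources),
                },
            }
        ],
    }


if __name__ == "__main__":
    print(json.dumps(tpu_pod_spec("tpu-v5-lite-podslice", "2x4", 8), indent=2))
```

The sketch only shows where each value goes. It doesn't check that the topology is supported by the TPU version.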

TPUs in Autopilot are compatible with the following capabilities:

  1. Spot Pods
  2. Specific capacity reservations
  3. Extended run time Pods

Plan your TPU configuration

Before you use this guide to deploy TPU workloads, plan your TPU configuration based on your model and how much memory it requires. For details, see Plan your TPU configuration.

Pricing

For pricing information, see Autopilot pricing.

Before you begin

Before you start, make sure you have performed the following tasks:

  • Enable the Google Kubernetes Engine API.
  • If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.
  • Ensure that you have an Autopilot cluster running GKE version 1.29.2-gke.1521000 or later.
  • To use reserved TPUs, ensure that you have an existing specific capacity reservation. For instructions, see Consuming reserved zonal resources.
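If you script the version prerequisite, a comparison along the following lines might help. This is a minimal sketch that assumes GKE version strings have the form MAJOR.MINOR.PATCH-gke.BUILD and that comparing the four numbers in order is a reasonable way to order them. The function names are hypothetical.

```python
import re


def parse_gke_version(version: str) -> tuple:
    """Split a version such as '1.29.2-gke.1521000' into comparable integers."""
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)-gke\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized GKE version: {version!r}")
    return tuple(int(part) for part in match.groups())


def meets_minimum(current: str, minimum: str = "1.29.2-gke.1521000") -> bool:
    """Return True if `current` is at least `minimum` (numeric comparison)."""
    return parse_gke_version(current) >= parse_gke_version(minimum)


if __name__ == "__main__":
    for candidate in ("1.28.1-gke.1066000", "1.29.2-gke.1521000", "1.31.2-gke.1088000"):
        print(candidate, meets_minimum(candidate))
```

The numbers are compared as integers rather than as strings, so a build number such as 900 sorts below 1066000.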

Ensure that you have TPU quota

The following sections help you ensure that you have enough quota when using TPUs in GKE.

To create TPU slice nodes, you must have TPU quota available unless you're using an existing capacity reservation. If you're using reserved TPUs, skip this section.

Creating TPU slice nodes in GKE requires Compute Engine API quota (compute.googleapis.com), not Cloud TPU API quota (tpu.googleapis.com). The name of the quota is different in regular Autopilot Pods and in Spot Pods.

To check the limit and current usage of your Compute Engine API quota for TPUs, follow these steps:

  1. Go to the Quotas page in the Google Cloud console:

    Go to Quotas

  2. In the Filter box, do the following:

    1. Select the Service property, enter Compute Engine API, and press Enter.

    2. Select the Type property and choose Quota.

    3. Select the Name property and enter the name of the quota based on the TPU version and value in the cloud.google.com/gke-tpu-accelerator node selector. For example, if you plan to create on-demand TPU v5e nodes whose value in the cloud.google.com/gke-tpu-accelerator node selector is tpu-v5-lite-podslice, enter TPU v5 Lite PodSlice chips.

      TPU version  | cloud.google.com/gke-tpu-accelerator | Name of the quota for on-demand instances | Name of the quota for Spot instances
      TPU v3       | tpu-v3-device        | TPU v3 Device chips        | Preemptible TPU v3 Device chips
      TPU v3       | tpu-v3-slice         | TPU v3 PodSlice chips      | Preemptible TPU v3 PodSlice chips
      TPU v4       | tpu-v4-podslice      | TPU v4 PodSlice chips      | Preemptible TPU v4 PodSlice chips
      TPU v5e      | tpu-v5-lite-device   | TPU v5 Lite Device chips   | Preemptible TPU v5 Lite Device chips
      TPU v5e      | tpu-v5-lite-podslice | TPU v5 Lite PodSlice chips | Preemptible TPU v5 Lite PodSlice chips
      TPU v5p      | tpu-v5p-slice        | TPU v5p chips              | Preemptible TPU v5p chips
      TPU Trillium | tpu-v6e-slice        | TPU v6e Slice chips        | Preemptible TPU v6e Lite PodSlice chips

    4. Select the Dimensions (e.g. locations) property and enter region: followed by the name of the region in which you plan to create TPUs in GKE. For example, enter region:us-west4 if you plan to create TPU slice nodes in the zone us-west4-a. TPU quota is regional, so all zones within the same region consume the same TPU quota.
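The lookup in the preceding steps can be sketched in Python as follows. The quota names are copied from the table in step 3. The function names are hypothetical, and deriving the region by dropping the final zone suffix is an assumption based on the us-west4-a example.

```python
# Selector value -> (on-demand quota name, Spot quota name), from the table in step 3.
TPU_QUOTA_NAMES = {
    "tpu-v3-device": ("TPU v3 Device chips", "Preemptible TPU v3 Device chips"),
    "tpu-v3-slice": ("TPU v3 PodSlice chips", "Preemptible TPU v3 PodSlice chips"),
    "tpu-v4-podslice": ("TPU v4 PodSlice chips", "Preemptible TPU v4 PodSlice chips"),
    "tpu-v5-lite-device": ("TPU v5 Lite Device chips", "Preemptible TPU v5 Lite Device chips"),
    "tpu-v5-lite-podslice": ("TPU v5 Lite PodSlice chips", "Preemptible TPU v5 Lite PodSlice chips"),
    "tpu-v5p-slice": ("TPU v5p chips", "Preemptible TPU v5p chips"),
    "tpu-v6e-slice": ("TPU v6e Slice chips", "Preemptible TPU v6e Lite PodSlice chips"),
}


def quota_name(selector_value: str, spot: bool = False) -> str:
    """Return the quota name to enter in the Name filter."""
    on_demand, spot_name = TPU_QUOTA_NAMES[selector_value]
    return spot_name if spot else on_demand


def region_dimension(zone: str) -> str:
    """Turn a zone such as 'us-west4-a' into the 'region:us-west4' filter value."""
    region = zone.rsplit("-", 1)[0]  # drop the trailing zone letter
    return f"region:{region}"


if __name__ == "__main__":
    print(quota_name("tpu-v5-lite-podslice"))
    print(region_dimension("us-west4-a"))
```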

If no quotas match the filter you entered, then the project has not been granted any of the specified quota for the region that you need, and you must request a TPU quota increase.

When a TPU reservation is created, both the limit and current use values for the corresponding quota increase by the number of chips in the TPU reservation. For example, when a reservation is created for 16 TPU v5e chips whose value in the cloud.google.com/gke-tpu-accelerator node selector is tpu-v5-lite-podslice, then both the Limit and Current usage for the TPU v5 Lite PodSlice chips quota in the relevant region increase by 16.

Quotas for additional GKE resources

You may need to increase the following GKE-related quotas in the regions where GKE creates your resources.

  • Persistent Disk SSD (GB) quota: The boot disk of each Kubernetes node requires 100GB by default. Therefore, this quota should be set at least as high as the product of the maximum number of GKE nodes you anticipate creating and 100GB (nodes * 100GB).
  • In-use IP addresses quota: Each Kubernetes node consumes one IP address. Therefore, this quota should be set at least as high as the maximum number of GKE nodes you anticipate creating.
  • Ensure that max-pods-per-node aligns with the subnet range: Each Kubernetes node uses secondary IP ranges for Pods. For example, max-pods-per-node of 32 requires 64 IP addresses which translates to a /26 subnet per node. Note that this range shouldn't be shared with any other cluster. To avoid exhausting the IP address range, use the --max-pods-per-node flag to limit the number of pods allowed to be scheduled on a node. The quota for max-pods-per-node should be set at least as high as the maximum number of GKE nodes you anticipate creating.
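The sizing rules above reduce to simple arithmetic. The following Python sketch shows one way to estimate the values. It assumes that the Pod range per node is twice max-pods-per-node rounded up to a power of two, which generalizes the single example above (32 Pods, 64 addresses, /26). Treat the function names and that rule as assumptions.

```python
BOOT_DISK_GB = 100  # default boot disk size per node, as stated above


def ssd_quota_gb(max_nodes: int) -> int:
    """Persistent Disk SSD quota: nodes * 100 GB."""
    return max_nodes * BOOT_DISK_GB


def in_use_ip_quota(max_nodes: int) -> int:
    """In-use IP addresses quota: one address per node."""
    return max_nodes


def pod_range_prefix(max_pods_per_node: int) -> int:
    """Prefix length of the per-node Pod range (assumed: 2x Pods, rounded up)."""
    addresses = 2 * max_pods_per_node
    host_bits = (addresses - 1).bit_length()  # smallest power of two >= addresses
    return 32 - host_bits


if __name__ == "__main__":
    nodes = 16  # hypothetical maximum node count
    print("SSD quota (GB):", ssd_quota_gb(nodes))
    print("In-use IP addresses:", in_use_ip_quota(nodes))
    print("Pod range per node: /%d" % pod_range_prefix(32))
```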

To request an increase in quota, see Request higher quota.

Options for provisioning TPUs in GKE

GKE Autopilot lets you use TPUs directly in individual workloads by using Kubernetes nodeSelectors.

Alternatively, you can request TPUs by using custom compute classes. Custom compute classes let platform administrators define a hierarchy of node configurations for GKE to prioritize during node scaling decisions, so that workloads run on your selected hardware.

For instructions, see the Centrally provision TPUs with custom compute classes section.

Prepare your TPU application

TPU workloads have the following preparation requirements.

  1. Frameworks like JAX, PyTorch, and TensorFlow access TPU VMs using the libtpu shared library. libtpu includes the XLA compiler, TPU runtime software, and the TPU driver. Each release of PyTorch and JAX requires a certain libtpu.so version. To use TPUs in GKE, ensure that you use the following versions:
    TPU type (cloud.google.com/gke-tpu-accelerator values) | libtpu.so version
    TPU Trillium (v6e) (tpu-v6e-slice) |
    TPU v5e (tpu-v5-lite-podslice, tpu-v5-lite-device) |
    TPU v5p (tpu-v5p-slice) | Recommended jax[tpu] version: 0.4.19 or later. Recommended torchxla[tpuvm] version: a nightly version built on October 23, 2023.
    TPU v4 (tpu-v4-podslice) |
    TPU v3 (tpu-v3-slice, tpu-v3-device) |
  2. Set the following environment variables for the container requesting the TPU resources:
    • TPU_WORKER_ID: A unique integer for each Pod. This ID denotes a unique worker-id in the TPU slice. The supported values for this field range from zero to the number of Pods minus one.
    • TPU_WORKER_HOSTNAMES: A comma-separated list of TPU VM hostnames or IP addresses that need to communicate with each other within the slice. There should be a hostname or IP address for each TPU VM in the slice. The list is ordered and zero-indexed by TPU_WORKER_ID.

    GKE automatically injects these environment variables by using a mutating webhook when you create a Job that sets completionMode: Indexed, subdomain, and parallelism > 1, and that requests google.com/tpu resources. GKE adds a headless Service so that DNS records are added for the Pods backing the Service.
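To make the relationship between the two variables concrete, the following Python sketch reads them inside a container and picks out the entry for the current worker. It relies only on the description above (a comma-separated, zero-indexed list). The function name and the sample hostnames are made up for illustration and aren't necessarily the values that GKE injects.

```python
import os


def tpu_worker_info(env=None):
    """Return (worker_id, own_hostname, all_hostnames) from the TPU variables."""
    env = os.environ if env is None else env
    worker_id = int(env["TPU_WORKER_ID"])
    hostnames = [h.strip() for h in env["TPU_WORKER_HOSTNAMES"].split(",") if h.strip()]
    # Valid IDs range from zero to the number of workers minus one.
    if not 0 <= worker_id < len(hostnames):
        raise ValueError(f"TPU_WORKER_ID {worker_id} is outside 0..{len(hostnames) - 1}")
    return worker_id, hostnames[worker_id], hostnames


if __name__ == "__main__":
    sample_env = {  # hypothetical values for a two-worker slice
        "TPU_WORKER_ID": "1",
        "TPU_WORKER_HOSTNAMES": "tpu-job-0.headless-svc,tpu-job-1.headless-svc",
    }
    print(tpu_worker_info(sample_env))
```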

After you complete the workload preparation, you can run a Job that uses TPUs.

Request TPUs in a workload

This section shows you how to create a Job that requests TPUs in Autopilot. In any workload that needs TPUs, you must specify the following:

  • Node selectors for the TPU version and topology
  • The number of TPU chips for a container in your workload

For a list of supported TPU versions, topologies, and the corresponding number of TPU chips and nodes in a slice, see Choose an Autopilot TPU configuration.

Considerations for TPU requests in workloads

Only one container in a Pod can use TPUs. The number of TPU chips that a container requests must be equal to the number of TPU chips attached to a node in the slice. For example, if you request TPU v5e (tpu-v5-lite-podslice) with a 2x4 topology, you can request any of the following:

  • 4 chips, which creates two multi-host nodes with 4 TPU chips each
  • 8 chips, which creates one single-host node with 8 TPU chips

As a best practice to maximize your cost efficiency, always consume all of the TPU chips in the slice that you request. If you request a multi-host slice of two nodes with 4 TPU chips each, deploy a workload that runs on both nodes and consumes all 8 TPU chips in the slice.
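The relationship between topology, chips per node, and node count in the preceding example can be sketched as follows. This assumes that the number of chips in a slice is the product of the topology dimensions, which matches the 2x4 (8 chips) and 2x2x4 (16 chips) cases on this page. The function names are hypothetical, and the sketch doesn't check whether a combination is actually supported by GKE.

```python
from math import prod


def chips_in_topology(topology: str) -> int:
    """Total chips in a slice, assuming it is the product of the dimensions."""
    return prod(int(dim) for dim in topology.split("x"))


def nodes_in_slice(topology: str, chips_per_node: int) -> int:
    """Number of nodes needed when each node contributes `chips_per_node` chips."""
    total = chips_in_topology(topology)
    if total % chips_per_node != 0:
        raise ValueError(f"{chips_per_node} chips per node does not divide {total}")
    return total // chips_per_node


if __name__ == "__main__":
    print(nodes_in_slice("2x4", 4))    # two nodes with 4 chips each
    print(nodes_in_slice("2x4", 8))    # one node with 8 chips
    print(nodes_in_slice("2x2x4", 4))  # four nodes with 4 chips each
```

For a multi-host slice, the node count is also the number of Pods that the workload needs in order to consume every chip in the slice.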

Create a workload that requests TPUs

The following steps create a Job that requests TPUs. If you have workloads that run on multi-host TPU slices, you must also create a headless Service that selects your workload by name. This headless Service lets Pods on different nodes in the multi-host slice communicate with each other by updating the Kubernetes DNS configuration to point at the Pods in the workload.

  1. Save the following manifest as tpu-autopilot.yaml:

    apiVersion: v1
    kind: Service
    metadata:
      name: headless-svc
    spec:
      clusterIP: None
      selector:
        job-name: tpu-job
    ---
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: tpu-job
    spec:
      backoffLimit: 0
      completions: 4
      parallelism: 4
      completionMode: Indexed
      template:
        spec:
          subdomain: headless-svc
          restartPolicy: Never
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: TPU_TYPE
            cloud.google.com/gke-tpu-topology: TOPOLOGY
          containers:
          - name: tpu-job
            image: python:3.10
            ports:
            - containerPort: 8471 # Default port that TPU VMs use to communicate
            - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
            command:
            - bash
            - -c
            - |
              pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
              python -c 'import jax; print("TPU cores:", jax.device_count())'
            resources:
              requests:
                cpu: 10
                memory: 500Gi
                google.com/tpu: NUMBER_OF_CHIPS
              limits:
                cpu: 10
                memory: 500Gi
                google.com/tpu: NUMBER_OF_CHIPS
    

    Replace the following:

    • TPU_TYPE: the TPU type to use, like tpu-v4-podslice. Must be a value supported by GKE.
    • TOPOLOGY: the arrangement of TPU chips in the slice, like 2x2x4. Must be a supported topology for the selected TPU type.
    • NUMBER_OF_CHIPS: the number of TPU chips for the container to use. Must be the same value for limits and requests.
  2. Deploy the Job:

    kubectl create -f tpu-autopilot.yaml
    

When you create this Job, GKE automatically does the following:

  1. Provisions nodes to run the Pods. Depending on the TPU type, topology, and resource requests that you specified, these nodes are either single-host slices or multi-host slices.
  2. Adds taints to the nodes and tolerations to the Pods to prevent any of your other workloads from running on the same nodes as TPU workloads.

Create a workload that requests TPUs and collection scheduling

In TPU Trillium, you can use collection scheduling to group TPU slice nodes. Grouping these TPU slice nodes makes it easier to adjust the number of replicas to meet the workload demand. Google Cloud controls software updates to ensure that sufficient slices within the collection are always available to serve traffic.

To learn about the limitations of collection scheduling, see How collection scheduling works.

Collection scheduling in single-host TPU slice nodes is available for Autopilot clusters in version 1.31.2-gke.1088000 and later. To create single-host TPU slice nodes and group them as a collection, add the cloud.google.com/gke-workload-type: HIGH_AVAILABILITY label to your workload specification.

For example, the following code block defines a collection with a single-host TPU slice:

  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
    cloud.google.com/gke-workload-type: HIGH_AVAILABILITY
  ...

Centrally provision TPUs with custom compute classes

To provision TPUs with a custom compute class, do the following:

  1. Ensure that your cluster has an available custom compute class that selects TPUs. To learn how to specify TPUs in custom compute classes, see TPU rules.

  2. Save the following manifest as tpu-job.yaml:

    apiVersion: v1
    kind: Service
    metadata:
      name: headless-svc
    spec:
      clusterIP: None
      selector:
        job-name: tpu-job
    ---
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: tpu-job
    spec:
      backoffLimit: 0
      completions: 4
      parallelism: 4
      completionMode: Indexed
      template:
        spec:
          subdomain: headless-svc
          restartPolicy: Never
          nodeSelector:
            cloud.google.com/compute-class: TPU_CLASS_NAME
          containers:
          - name: tpu-job
            image: python:3.10
            ports:
            - containerPort: 8471 # Default port that TPU VMs use to communicate
            - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
            command:
            - bash
            - -c
            - |
              pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
              python -c 'import jax; print("TPU cores:", jax.device_count())'
            resources:
              requests:
                cpu: 10
                memory: 500Gi
                google.com/tpu: NUMBER_OF_CHIPS
              limits:
                cpu: 10
                memory: 500Gi
                google.com/tpu: NUMBER_OF_CHIPS
    

    Replace the following:

    • TPU_CLASS_NAME: the name of the existing custom compute class that specifies TPUs.
    • NUMBER_OF_CHIPS: the number of TPU chips for the container to use. Must be the same value for limits and requests, equal to the value in the tpu.count field in the selected custom compute class.
  3. Deploy the Job:

    kubectl create -f tpu-job.yaml
    

When you create this Job, GKE automatically does the following:

  • Provisions nodes to run the Pods. Depending on the TPU type, topology, and resource requests that you specified, these nodes are either single-host slices or multi-host slices. Depending on the availability of TPU resources in the top priority, GKE might fall back to lower priorities to maximize obtainability.
  • Adds taints to the nodes and tolerations to the Pods to prevent any of your other workloads from running on the same nodes as TPU workloads.

To learn more, see About custom compute classes.

Example: Display the total TPU chips in a multi-host slice

The following workload returns the number of TPU chips across all of the nodes in a multi-host TPU slice. To create a multi-host slice, the workload has the following parameters:

  • TPU version: TPU v4
  • Topology: 2x2x4

This version and topology selection results in a multi-host slice.

  1. Save the following manifest as available-chips-multihost.yaml:
    apiVersion: v1
    kind: Service
    metadata:
      name: headless-svc
    spec:
      clusterIP: None
      selector:
        job-name: tpu-available-chips
    ---
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: tpu-available-chips
    spec:
      backoffLimit: 0
      completions: 4
      parallelism: 4
      completionMode: Indexed
      template:
        spec:
          subdomain: headless-svc
          restartPolicy: Never
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
            cloud.google.com/gke-tpu-topology: 2x2x4
          containers:
          - name: tpu-job
            image: python:3.10
            ports:
            - containerPort: 8471 # Default port that TPU VMs use to communicate
            - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
            command:
            - bash
            - -c
            - |
              pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
              python -c 'import jax; print("TPU cores:", jax.device_count())'
            resources:
              requests:
                cpu: 10
                memory: 500Gi
                google.com/tpu: 4
              limits:
                cpu: 10
                memory: 500Gi
                google.com/tpu: 4
  2. Deploy the manifest:
    kubectl create -f available-chips-multihost.yaml
    

    GKE runs a TPU v4 slice with four VMs (multi-host TPU slice). The slice has 16 interconnected TPU chips.

  3. Verify that the Job created four Pods:
    kubectl get pods
    

    The output is similar to the following:

    NAME                          READY   STATUS      RESTARTS   AGE
    tpu-available-chips-0-5cd8r   0/1     Completed   0          97s
    tpu-available-chips-1-lqqxt   0/1     Completed   0          97s
    tpu-available-chips-2-f6kwh   0/1     Completed   0          97s
    tpu-available-chips-3-m8b5c   0/1     Completed   0          97s
    
  4. Get the logs of one of the Pods:
    kubectl logs POD_NAME
    

    Replace POD_NAME with the name of one of the created Pods. For example, tpu-available-chips-0-5cd8r.

    The output is similar to the following:

    TPU cores: 16
    

Example: Display the TPU chips in a single node

The following workload is a standalone Pod that displays the number of TPU chips that are attached to a specific node. To create a single-host node, the workload has the following parameters:

  • TPU version: TPU v5e
  • Topology: 2x4

This version and topology selection results in a single-host slice.

  1. Save the following manifest as available-chips-singlehost.yaml:
    apiVersion: v1
    kind: Pod
    metadata:
      name: tpu-job-jax-v5
    spec:
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
        cloud.google.com/gke-tpu-topology: 2x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        command:
        - bash
        - -c
        - |
          pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 8
          limits:
            google.com/tpu: 8
  2. Deploy the manifest:
    kubectl create -f available-chips-singlehost.yaml
    

    GKE provisions a node with a single-host TPU slice that uses TPU v5e. The node has eight TPU chips.

  3. Get the logs of the Pod:
    kubectl logs tpu-job-jax-v5
    

    The output is similar to the following:

    Total TPU chips: 8
    

Observability and metrics

Dashboard

In the Kubernetes Clusters page in the Google Cloud console, the Observability tab displays the TPU observability metrics. For more information, see GKE observability metrics.

The TPU dashboard is populated only if you have system metrics enabled in your GKE cluster.

Runtime metrics

In GKE version 1.27.4-gke.900 or later, TPU workloads that use JAX version 0.4.14 or later and specify containerPort: 8431 export TPU utilization metrics as GKE system metrics. The following metrics are available in Cloud Monitoring to monitor your TPU workload's runtime performance:

  • Duty cycle: Percentage of time over the past sampling period (60 seconds) during which the TensorCores were actively processing on a TPU chip. A larger percentage means better TPU utilization.
  • Memory used: Amount of accelerator memory allocated in bytes. Sampled every 60 seconds.
  • Memory total: Total accelerator memory in bytes. Sampled every 60 seconds.

These metrics are located in the Kubernetes node (k8s_node) and Kubernetes container (k8s_container) schema.

Kubernetes container:

  • kubernetes.io/container/accelerator/duty_cycle
  • kubernetes.io/container/accelerator/memory_used
  • kubernetes.io/container/accelerator/memory_total

Kubernetes node:

  • kubernetes.io/node/accelerator/duty_cycle
  • kubernetes.io/node/accelerator/memory_used
  • kubernetes.io/node/accelerator/memory_total
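If you query these metrics programmatically, you typically pass a filter string that names the metric type and the monitored resource type. The following Python sketch only assembles such a string from the names listed above. The helper name is hypothetical, and the filter shape (metric.type and resource.type joined by AND) is offered as an assumption about the Cloud Monitoring filter syntax that you should verify against the Monitoring documentation.

```python
RESOURCE_TYPES = {"container": "k8s_container", "node": "k8s_node"}


def accelerator_metric_filter(scope: str, metric: str) -> str:
    """Build a filter string for one accelerator metric (hypothetical helper).

    scope:  "container" or "node"
    metric: for example "duty_cycle", "memory_used", or "memory_total"
    """
    if scope not in RESOURCE_TYPES:
        raise ValueError(f"scope must be one of {sorted(RESOURCE_TYPES)}")
    metric_type = f"kubernetes.io/{scope}/accelerator/{metric}"
    return f'metric.type = "{metric_type}" AND resource.type = "{RESOURCE_TYPES[scope]}"'


if __name__ == "__main__":
    print(accelerator_metric_filter("container", "duty_cycle"))
    print(accelerator_metric_filter("node", "memory_used"))
```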

Host metrics

In GKE version 1.28.1-gke.1066000 or later, VMs in a TPU slice export TPU utilization metrics as GKE system metrics. The following metrics are available in Cloud Monitoring to monitor your TPU host's performance:

  • TensorCore utilization: Current percentage of the TensorCore that is utilized. The TensorCore value equals the sum of the matrix-multiply units (MXUs) plus the vector unit. The TensorCore utilization value is calculated by dividing the TensorCore operations that were performed over the past sample period (60 seconds) by the supported number of TensorCore operations over the same period. A larger value means better utilization.
  • Memory Bandwidth utilization: Current percentage of the accelerator memory bandwidth that is being used. Computed by dividing the memory bandwidth used over a sample period (60s) by the maximum supported bandwidth over the same sample period.
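Both host metrics follow the same pattern: an amount observed over the sample period divided by the maximum supported over that period, expressed as a percentage. The following Python sketch restates that ratio. The function name and the sample numbers are illustrative only and don't come from real TPU measurements.

```python
def utilization_percent(observed: float, supported_maximum: float) -> float:
    """Return observed / supported_maximum as a percentage for one sample period."""
    if supported_maximum <= 0:
        raise ValueError("supported_maximum must be positive")
    return 100.0 * observed / supported_maximum


if __name__ == "__main__":
    # Hypothetical counts over one 60-second sample period.
    performed_ops = 3.0e15
    supported_ops = 1.2e16
    print(f"TensorCore utilization: {utilization_percent(performed_ops, supported_ops):.1f}%")
```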

These metrics are located in the Kubernetes node (k8s_node) and Kubernetes container (k8s_container) schema.

Kubernetes container:

  • kubernetes.io/container/accelerator/tensorcore_utilization
  • kubernetes.io/container/accelerator/memory_bandwidth_utilization

Kubernetes node:

  • kubernetes.io/node/accelerator/tensorcore_utilization
  • kubernetes.io/node/accelerator/memory_bandwidth_utilization

For more information, see Kubernetes metrics and GKE system metrics.

Logging

Logs emitted by containers running on GKE nodes, including TPU VMs, are collected by the GKE logging agent and sent to Logging, where you can view them.

Recommendations for TPU workloads in Autopilot

The following recommendations might improve the efficiency of your TPU workloads:

  • Use extended run time Pods for a grace period of up to seven days before GKE terminates your Pods for scale-downs or node upgrades. You can use maintenance windows and exclusions with extended run time Pods to further delay automatic node upgrades.
  • Use capacity reservations to ensure that your workloads receive requested TPUs without being placed in a queue for availability.