rohan-varma/xla, forked from pytorch/xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)

PyTorch/XLA


Note: PyTorch/XLA r2.1 will be the last release with XRT available as a legacy runtime. Our main release build will not include XRT, but it will be available in a separate package.

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!

Take a look at one of our Kaggle notebooks to get started:

Getting Started

To install PyTorch/XLA on a new VM:

pip install torch~=2.0.0 https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
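The command above pins `torch` with pip's compatible-release operator (`~=2.0.0`, meaning `>=2.0.0` and `==2.0.*`) and installs a wheel whose filename carries the `cp38` tag, so pip will only accept it on CPython 3.8. The standard-library-only sketch below checks both conditions before you install. The helper names are hypothetical, not part of PyTorch/XLA, and the version parsing deliberately ignores pre-releases.

```python
import re
import sys
from typing import Sequence


def python_tag(version_info: Sequence[int] = sys.version_info) -> str:
    """Return the CPython wheel tag for an interpreter version, e.g. 'cp38'."""
    return f"cp{version_info[0]}{version_info[1]}"


def wheel_python_tag(wheel_url: str) -> str:
    """Extract the Python tag from a wheel URL.

    Wheel filenames follow name-version-python-abi-platform.whl, so the
    Python tag is the third dash-separated field from the end.
    """
    filename = wheel_url.rsplit("/", 1)[-1]
    return filename[: -len(".whl")].split("-")[-3]


def satisfies_compatible_release(installed: str, spec: str) -> bool:
    """Check a PEP 440 '~=' constraint for plain numeric versions.

    '~=2.0.0' means '>=2.0.0' and '==2.0.*'. Local labels such as '+cu118'
    are ignored; pre-releases are not handled in this sketch.
    """
    match = re.match(r"\d+(?:\.\d+)*", installed)
    if match is None:
        raise ValueError(f"unrecognized version: {installed!r}")
    have = [int(part) for part in match.group(0).split(".")]
    want = [int(part) for part in spec.split(".")]
    # Pad with zeros so that '2.0' compares like '2.0.0'.
    width = max(len(have), len(want))
    have += [0] * (width - len(have))
    want_padded = want + [0] * (width - len(want))
    prefix = want[:-1]  # '~=2.0.0' fixes the '2.0' prefix
    return have >= want_padded and have[: len(prefix)] == prefix


WHEEL = "https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl"
print(wheel_python_tag(WHEEL) == python_tag())  # True only on CPython 3.8
```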

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xm.xla_device())
+
+  # MpDeviceLoader preloads data to the XLA device
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
-    optimizer.step()
+
+    # `xm.optimizer_step` combines gradients across replicas
+    xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xmp.spawn automatically selects the correct world size
+  xmp.spawn(_mp_fn, args=())
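Conceptually, each replica in the loop above computes gradients on its own slice of the data, and `xm.optimizer_step` reduces those gradients across replicas (an all-reduce that averages them) before applying the update, so every replica keeps identical parameters. The dependency-free toy below illustrates that idea for a one-parameter linear model fit with mean squared error. `all_reduce_mean` is a hypothetical stand-in for the on-device collective; this is an illustration of the concept, not how PyTorch/XLA implements it.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair for the model y ~ w * x


def local_gradient(w: float, shard: List[Sample]) -> float:
    """d/dw of the mean squared error over one replica's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> float:
    """Stand-in for a cross-replica all-reduce that averages gradients."""
    return sum(values) / len(values)


def train_step(weights: List[float], shards: List[List[Sample]], lr: float) -> List[float]:
    """One synchronous data-parallel step: each replica computes a local
    gradient, the gradients are averaged, and every replica applies the
    same averaged update to its own copy of the weight."""
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    g = all_reduce_mean(grads)
    return [w - lr * g for w in weights]


if __name__ == "__main__":
    data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # y = 2x
    shards = [data[0:2], data[2:4]]  # two replicas, equal-sized shards
    weights = [0.0, 0.0]  # every replica starts from the same weight
    for _ in range(50):
        weights = train_step(weights, shards, lr=0.05)
    print(weights)  # both replicas hold the same value, close to 2.0
```

With equal-sized shards, the averaged gradient equals the full-batch gradient (both are -30.0 at w = 0 for this data), which is why synchronous data parallelism behaves like training on one large batch.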

If you're using DistributedDataParallel, make the following changes:

-import os
 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend
 from torch.nn.parallel import DistributedDataParallel as DDP

-def _mp_fn(rank, world_size):
+def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xm.xla_device())
+  # `gradient_as_bucket_view=True` is required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = ddp_model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
     optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(world_size,), nprocs=world_size)
+  xmp.spawn(_mp_fn, args=())
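Both `torch.multiprocessing.spawn` and `xmp.spawn` call the per-process function as `fn(index, *args)`: the process index comes first, followed by the contents of `args`. With `args=()`, the function therefore receives only the index, and any extra parameter such as a world size must be supplied through `args`. The sequential sketch below mimics that calling convention in a single process; `fake_spawn` and the two target functions are hypothetical and do not start real processes.

```python
from typing import Any, Callable, List, Tuple


def fake_spawn(fn: Callable[..., Any], args: Tuple[Any, ...] = (), nprocs: int = 2) -> List[Any]:
    """Sequential stand-in for a process launcher: calls fn(index, *args)
    once per simulated process and collects the return values."""
    return [fn(index, *args) for index in range(nprocs)]


def mp_fn_index_only(index: int) -> str:
    return f"process {index}"


def mp_fn_with_world_size(rank: int, world_size: int) -> str:
    return f"rank {rank} of {world_size}"


if __name__ == "__main__":
    print(fake_spawn(mp_fn_index_only, args=()))
    # An extra positional parameter has to be passed through `args`.
    print(fake_spawn(mp_fn_with_world_size, args=(2,)))
    try:
        fake_spawn(mp_fn_with_world_size, args=())
    except TypeError as err:
        print("missing argument:", err)
```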

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, GPU, CPU and...).

Our comprehensive user guides are available at:

Documentation for the latest release

Documentation for master branch

PyTorch/XLA tutorials

Available docker images and wheels

Wheel

Version | Cloud TPU VMs wheel
--- | ---
2.0 (Python 3.8) | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
nightly >= 2023/04/25 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
nightly >= 2023/04/25 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl

Older versions:

Version | Cloud TPU VMs wheel
--- | ---
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.13-cp38-cp38-linux_x86_64.whl
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.12-cp38-cp38-linux_x86_64.whl
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.11-cp38-cp38-linux_x86_64.whl
1.10 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-1.10-cp38-cp38-linux_x86_64.whl
nightly <= 2023/04/25 | https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl

Note: For TPU Pod customers using XRT (our legacy runtime), we have custom wheels for torch, torchvision, and torch_xla at https://storage.googleapis.com/tpu-pytorch/wheels/xrt.

Package | Cloud TPU VMs wheel (XRT on Pod, legacy only)
--- | ---
torch_xla | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
torch | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torch-2.0-cp38-cp38-linux_x86_64.whl
torchvision | https://storage.googleapis.com/tpu-pytorch/wheels/xrt/torchvision-2.0-cp38-cp38-linux_x86_64.whl

Version | GPU wheel (Python 3.8)
--- | ---
2.0, CUDA 11.8 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
2.0, CUDA 11.7 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/117/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp38-cp38-linux_x86_64.whl
nightly, CUDA 12.0, >= 2023/06/27 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
nightly, CUDA 11.8, <= 2023/04/25 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/118/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
nightly, CUDA 11.8, >= 2023/04/25 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-nightly-cp38-cp38-linux_x86_64.whl

Version | GPU wheel (Python 3.7)
--- | ---
1.13 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.13-cp37-cp37m-linux_x86_64.whl
1.12 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl
1.11 | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-1.11-cp37-cp37m-linux_x86_64.whl
nightly | https://storage.googleapis.com/tpu-pytorch/wheels/cuda/112/torch_xla-nightly-cp37-cp37-linux_x86_64.whl

Version | Colab TPU wheel
--- | ---
2.0 | https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl

You can also add yyyymmdd after torch_xla-nightly to get the nightly wheel for a specific date. To get the companion PyTorch and torchvision nightly wheels, replace torch_xla with torch or torchvision in the wheel links above.
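Both substitutions are plain string edits on the wheel filename. The sketch below uses hypothetical helper names. It inserts the date directly after `torch_xla-nightly`, as the sentence above describes, and exposes the separator as a parameter because the exact dated filename format is an assumption you should confirm against the bucket listing before relying on it.

```python
def dated_nightly(url: str, yyyymmdd: str, separator: str = "") -> str:
    """Insert a date after 'torch_xla-nightly' in a nightly wheel URL.

    `separator` is an assumption: verify the dated filename in the bucket.
    """
    if len(yyyymmdd) != 8 or not yyyymmdd.isdigit():
        raise ValueError("date must be in yyyymmdd form")
    marker = "torch_xla-nightly"
    if marker not in url:
        raise ValueError("not a torch_xla nightly wheel URL")
    return url.replace(marker, marker + separator + yyyymmdd, 1)


def companion_wheel(url: str, package: str) -> str:
    """Swap torch_xla for torch or torchvision in the wheel filename only,
    leaving the bucket path untouched."""
    if package not in ("torch", "torchvision"):
        raise ValueError("package must be 'torch' or 'torchvision'")
    base, filename = url.rsplit("/", 1)
    if not filename.startswith("torch_xla-"):
        raise ValueError("expected a torch_xla wheel filename")
    return f"{base}/{package}-{filename[len('torch_xla-'):]}"


NIGHTLY = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl"
print(companion_wheel(NIGHTLY, "torch"))
```

The helpers compose, so `companion_wheel(dated_nightly(url, date), "torchvision")` yields the dated torchvision counterpart under the same assumption about the date format.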

Installing libtpu (PyTorch/XLA 2.0 and older)

For PyTorch/XLA r2.0 and older, and when developing PyTorch/XLA, install the libtpu pip package with the following command:

pip3 install torch_xla[tpuvm]

This is only required on Cloud TPU VMs.
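The rule above can be written as a small predicate. This is a sketch with a hypothetical helper name, assuming release strings of the form `major.minor` (an optional leading `r` or trailing patch number is tolerated), as used elsewhere in this README.

```python
def needs_libtpu_install(release: str, developing: bool = False, on_tpu_vm: bool = True) -> bool:
    """Return True if `pip3 install torch_xla[tpuvm]` is needed: only on
    Cloud TPU VMs, and only for r2.0 and older or when developing PyTorch/XLA."""
    if not on_tpu_vm:
        return False
    if developing:
        return True
    major, minor = (int(part) for part in release.lstrip("r").split(".")[:2])
    return (major, minor) <= (2, 0)


print(needs_libtpu_install("r2.0"))
```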

Docker

Version | Cloud TPU VMs Docker image
--- | ---
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_tpuvm
nightly, Python 3.10 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm
nightly, Python 3.8 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm
nightly, Python 3.10, at date (>= 2023/04/25) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_YYYYMMDD
nightly, Python 3.8, at date (>= 2023/04/25) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_YYYYMMDD
nightly at date (< 2023/04/25) | gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm_YYYYMMDD
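The dated nightly TPU VM images follow a date-based rule: images from 2023/04/25 onward are in the `us-central1-docker.pkg.dev` registry, and older ones are in `gcr.io`, where only Python 3.8 is listed. The sketch below, with a hypothetical helper name, builds the image reference from a Python version and a date; the resulting string can be passed to `docker pull`.

```python
import datetime

NEW_REGISTRY = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla"
OLD_REGISTRY = "gcr.io/tpu-pytorch/xla"
CUTOVER = datetime.date(2023, 4, 25)  # first date served from the new registry


def tpu_nightly_image(python: str, date: datetime.date) -> str:
    """Return the dated nightly TPU VM image reference for '3.8' or '3.10'."""
    stamp = date.strftime("%Y%m%d")
    if date >= CUTOVER:
        if python not in ("3.8", "3.10"):
            raise ValueError("dated nightly TPU VM images are listed for Python 3.8 and 3.10")
        return f"{NEW_REGISTRY}:nightly_{python}_tpuvm_{stamp}"
    if python != "3.8":
        raise ValueError("images before 2023/04/25 are listed for Python 3.8 only")
    return f"{OLD_REGISTRY}:nightly_3.8_tpuvm_{stamp}"


print(tpu_nightly_image("3.10", datetime.date(2023, 6, 1)))
```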

Version | GPU Docker image (CUDA 12.0, Python 3.8)
--- | ---
nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.0
nightly at date (>= 2023/06/27) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.0_YYYYMMDD

Version | GPU Docker image (CUDA 11.8, Python 3.8)
--- | ---
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.8
nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8
nightly at date (>= 2023/04/25) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_YYYYMMDD
nightly at date (< 2023/04/25) | gcr.io/tpu-pytorch/xla:nightly_3.8_cuda_11.8_YYYYMMDD

Version | GPU Docker image (CUDA 11.7, Python 3.8)
--- | ---
2.0 | gcr.io/tpu-pytorch/xla:r2.0_3.8_cuda_11.7

Version | GPU Docker image (CUDA 11.2, Python 3.8)
--- | ---
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.8_cuda_11.2

Version | GPU Docker image (CUDA 11.2, Python 3.7)
--- | ---
1.13 | gcr.io/tpu-pytorch/xla:r1.13_3.7_cuda_11.2
1.12 | gcr.io/tpu-pytorch/xla:r1.12_3.7_cuda_11.2

Use the GPU images above to run on compute instances with GPUs.

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue in this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Facebook, and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Facebook, please send an email to [email protected]. For questions directed at Google, please send an email to [email protected]. For all other questions, please open an issue in this repository.

Additional Reads

You can find additional useful reading materials in
