[Fix] zero optimizer w/ tensor parallel test (#167)
## Title

- [Fix] zero optimizer w/ tensor parallel test

## Description

- ZeRO was not actually running in tensor parallel mode, so I fixed this by switching the test to a model from `transformers` (see the sketch below).
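
For quick reference, the heart of the new setup, condensed from the diff below, looks roughly like this (the `build` helper is hypothetical; the distributed launcher and the `parallel_context` it passes in come from the existing test harness):

```python
import copy

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import oslo
from oslo.torch.nn.parallel import TensorParallel
from oslo.torch.nn.parallel.data_parallel.zero import ZeroRedundancyOptimizer


def build(parallel_context):
    # A transformers model replaces the hand-written MLP used by the old test.
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

    # Tensor-parallel copy made ready with oslo, plus the plain model as a reference.
    hybrid_model = TensorParallel(copy.deepcopy(model), parallel_context=parallel_context)
    oslo.ready(hybrid_model, parallel_context)
    zero_model = model.cuda()

    # ZeRO wraps the optimizer on both paths; the tensor-parallel path also
    # partitions gradients.
    hybrid_optimizer = ZeroRedundancyOptimizer(
        torch.optim.Adam(hybrid_model.parameters(), lr=1e-2),
        parallel_context=parallel_context,
        overlap_communication=True,
        partition_grad=True,
    )
    zero_optimizer = ZeroRedundancyOptimizer(
        torch.optim.Adam(zero_model.parameters(), lr=1e-2),
        parallel_context=parallel_context,
        overlap_communication=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    return hybrid_model, zero_model, hybrid_optimizer, zero_optimizer, tokenizer
```

The test then runs the same tokenized batch through both models, checks that the losses match, steps both optimizers, and compares the updated parameters shard-by-shard.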

## Linked Issues

- N/A
yhna940 authored Mar 28, 2023
1 parent c103c5b commit 3357dac
Showing 1 changed file with 24 additions and 28 deletions.
@@ -7,29 +7,18 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.utils import get_free_port, set_seed
from oslo.torch.nn.parallel.data_parallel.zero import ZeroRedundancyOptimizer
from torch.testing import assert_close
from oslo.torch.nn.parallel import TensorParallel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

skip_if_dist_unavailable = pytest.mark.skipif(
torch.cuda.device_count() < 2, reason="dist required"
)


class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


def assert_shard_close(
tensor: torch.Tensor,
shard: torch.Tensor,
@@ -40,14 +29,14 @@ def assert_shard_close(
):
assert tensor.ndim == shard.ndim
if tensor.shape == shard.shape:
return assert_close(tensor, shard, rtol=rtol, atol=atol)
return torch.allclose(tensor, shard, rtol=rtol, atol=atol)
else:
dims_not_eq = torch.nonzero(
torch.tensor(tensor.shape) != torch.tensor(shard.shape)
)
if dims_not_eq.numel() == 1:
dim = dims_not_eq.item()
return assert_close(
return torch.allclose(
tensor.chunk(world_size, dim)[rank], shard, rtol=rtol, atol=atol
)
else:
@@ -58,46 +47,53 @@ def run(parallel_context: ParallelContext):
local_rank = torch.distributed.get_rank()

# create model
model = MlpModel().cuda()
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
hybrid_model = TensorParallel(
copy.deepcopy(model), parallel_context=parallel_context
)
zero_model = model
oslo.ready(hybrid_model, parallel_context)
zero_model = model.cuda()

# create optimizer
hybrid_optimizer = ZeroRedundancyOptimizer(
torch.optim.Adam(hybrid_model.parameters(), lr=1),
torch.optim.Adam(hybrid_model.parameters(), lr=1e-2),
parallel_context=parallel_context,
overlap_communication=True,
partition_grad=True,
)
zero_optimizer = ZeroRedundancyOptimizer(
torch.optim.Adam(zero_model.parameters(), lr=1),
torch.optim.Adam(zero_model.parameters(), lr=1e-2),
parallel_context=parallel_context,
overlap_communication=True,
)

# create tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# create data
set_seed(2021 + local_rank)
input_data = torch.randn(32, 128).cuda()
input_text = ["This is a sample text."] * 32
inputs = tokenizer(
input_text, return_tensors="pt", padding=True, truncation=True
).to("cuda")
labels = torch.randint(0, model.config.num_labels, (32,)).long().cuda()

# zero-dp forward
hybrid_output = hybrid_model(input_data)
zero_output = zero_model(input_data)
hybrid_output = hybrid_model(**inputs, labels=labels).loss
zero_output = zero_model(**inputs, labels=labels).loss

assert torch.allclose(hybrid_output, zero_output)

# zero-dp backward
hybrid_output.sum().float().backward()
zero_output.sum().float().backward()
hybrid_output.backward()
zero_output.backward()

# step
hybrid_optimizer.step()
zero_optimizer.step()

# check updated param
for hp, zp in zip(hybrid_model.parameters(), zero_model.parameters()):
assert torch.allclose(hp.data, zp.data)
assert assert_shard_close(
zp.data, hp.data, local_rank, torch.distributed.get_world_size()
)


def run_dist(rank, world_size):
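
The hunks end here; the `run_dist` launcher and the rest of the file below it are unchanged by this commit and not shown. One conceptual note on the changed parameter check: because `TensorParallel` now holds genuine shards of each weight, the test compares the rank-local chunk of the full (ZeRO-only) parameter against the shard via `assert_shard_close` instead of a plain `torch.allclose`. A single-process sketch of that comparison, with made-up shapes, rank, and world size:

```python
import torch

# Pretend setup: "full" is a parameter as the plain ZeRO model sees it, and
# "shard" is the slice that a tensor-parallel rank would hold locally.
full = torch.randn(8, 16)
world_size, rank = 2, 1
shard = full.chunk(world_size, dim=1)[rank].clone()

# Same logic as assert_shard_close: locate the single mismatched dimension...
dims_not_eq = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape))
assert dims_not_eq.numel() == 1
dim = dims_not_eq.item()

# ...then compare only the rank-local chunk of the full tensor with the shard.
assert torch.allclose(full.chunk(world_size, dim)[rank], shard)
```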