[SPMD] Preserve parameter sharding with output data sharding #4721

Merged 4 commits on Mar 8, 2023
58 changes: 56 additions & 2 deletions test/cpp/test_xla_sharding.cpp
@@ -160,9 +160,9 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto());
XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto());
EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
EXPECT_TRUE(!ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
EXPECT_TRUE(!ShardingUtil::EqualShardingSpecs(tiled_2d, replicated));
EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, replicated));
}

TEST_F(XLAShardingTest, CreateTensorsData) {
@@ -239,5 +239,59 @@ TEST_F(XLAShardingTest, InputHandler) {
at::kFloat));
}

TEST_F(XLAShardingTest, OutputHandler) {
if ((xla::sys_util::GetEnvString(xla::env::kEnvPjRtDevice, "") == "") ||
(xla::ComputationClient::Get()->GetLocalDevices().size() < 2)) {
    GTEST_SKIP()
        << "`PJRT_DEVICE` is not set or fewer than 2 local devices are available ("
        << xla::ComputationClient::Get()->GetLocalDevices().size()
        << " local devices detected).";
}

std::vector<std::string> devices =
xla::ComputationClient::Get()->GetLocalDevices();

  // Prepare an input vector `outputs` with 2 arguments per device.
std::vector<std::vector<xla::ComputationClient::DataPtr>> outputs;
outputs.reserve(devices.size());
at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
for (auto device : devices) {
outputs.push_back(
UnwrapXlaData(CreateTensorsData({tensor, tensor}, {device, device})));
}

xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Tile1D(
CreateComputationShapeFromTensor(tensor, GetDefaultDevice()),
devices.size())
.ToProto(),
tensor_shape);
std::vector<XLATensor::ShardingSpecPtr> sharding_specs{sharding_spec,
sharding_spec};

// Shard a PjRtData into a PjRtShardedData.
std::vector<xla::ComputationClient::DataPtr> sharded_outputs =
ShardingUtil::OutputHandler(outputs, sharding_specs,
/*replicated_output=*/true);
EXPECT_EQ(sharded_outputs.size(), 2);
auto shards =
xla::ComputationClient::Get()->GetDataShards(sharded_outputs[0]);
EXPECT_EQ(shards.size(), devices.size());
EXPECT_FALSE(
xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape));

// Wrap sharded data into a PjRtShardedData with `devices.size()` shards.
std::vector<xla::ComputationClient::DataPtr> wrapped_outputs =
ShardingUtil::OutputHandler(outputs, sharding_specs,
/*replicated_output=*/false);
EXPECT_EQ(wrapped_outputs.size(), 2);
shards = xla::ComputationClient::Get()->GetDataShards(wrapped_outputs[0]);
EXPECT_EQ(shards.size(), devices.size());
EXPECT_TRUE(
xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape));
}

} // namespace cpp_test
} // namespace torch_xla
9 changes: 6 additions & 3 deletions test/spmd/test_xla_sharding.py
@@ -111,6 +111,7 @@ def test_mark_step_with_sharding(self):
self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))

def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
model = self.SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, self._get_mesh((1, self.n_devices)),
(0, 1))
@@ -121,15 +122,17 @@ def test_optimizer_step_with_sharding(self):
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(5):
for i in range(3):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
# Sharding is persisted across mark_step calls; verify that the sharded
# computation can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))

def test_inplace_add_with_sharding(self):
xt = torch.ones(2, 2).to(xm.xla_device())
9 changes: 9 additions & 0 deletions third_party/xla_client/computation_client.h
@@ -209,6 +209,15 @@ class ComputationClient {
// wrapped inside a vector.
virtual std::vector<DataPtr> GetDataShards(DataPtr data) = 0;

// Returns wrapped data shards as PjRtShardedData.
virtual DataPtr WrapDataShards(const std::vector<DataPtr>& shards,
std::string device, xla::Shape shape,
xla::OpSharding sharding) = 0;

// Returns OpSharding attached to PjRtShardedData. The returned optional
// structure will be empty if there is no sharding, like with PjRtData.
virtual std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) = 0;

// Transfers local tensor values to the TPU devices and fetches the handles.
virtual std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) = 0;
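The two additions pair naturally: WrapDataShards reassembles per-device handles into a single sharded handle, and GetDataSharding reads the OpSharding back off a handle (empty for plain, unsharded data). Below is a minimal usage sketch, not part of this PR: the helper name is hypothetical, the 1D tile sharding is an illustrative choice, and the extra includes for HloSharding and XLA_CHECK are assumed to be available.

// Hypothetical helper: wrap per-device output handles into one sharded
// handle and confirm the sharding annotation survives the round trip.
#include <optional>
#include <string>
#include <vector>

#include "third_party/xla_client/computation_client.h"

namespace torch_xla {

xla::ComputationClient::DataPtr WrapAndVerify(
    const std::vector<xla::ComputationClient::DataPtr>& shards,
    const xla::Shape& global_shape, const std::string& device) {
  // Tile the single dimension across the shards (illustrative choice only).
  xla::OpSharding sharding =
      xla::HloSharding::Tile1D(global_shape, shards.size()).ToProto();
  xla::ComputationClient::DataPtr wrapped =
      xla::ComputationClient::Get()->WrapDataShards(shards, device,
                                                    global_shape, sharding);
  // Sharded handles echo their OpSharding; plain PjRtData yields an empty optional.
  std::optional<xla::OpSharding> attached =
      xla::ComputationClient::Get()->GetDataSharding(wrapped);
  XLA_CHECK(attached.has_value());
  return wrapped;
}

}  // namespace torch_xla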
34 changes: 32 additions & 2 deletions third_party/xla_client/pjrt_computation_client.cc
@@ -143,6 +143,29 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::GetDataShards(
return shards;
}

ComputationClient::DataPtr PjRtComputationClient::WrapDataShards(
const std::vector<DataPtr>& shards, std::string device, xla::Shape shape,
xla::OpSharding sharding) {
std::vector<std::shared_ptr<PjRtData>> pjrt_data_shards;
pjrt_data_shards.reserve(shards.size());
for (auto& shard : shards) {
XLA_CHECK(shard != nullptr);
auto pjrt_shard = dynamic_cast<PjRtData*>(shard.get());
pjrt_data_shards.push_back(std::make_shared<PjRtData>(
pjrt_shard->device(), pjrt_shard->shape(), pjrt_shard->buffer));
}
return std::make_shared<PjRtShardedData>(device, shape, pjrt_data_shards,
sharding);
}

std::optional<xla::OpSharding> PjRtComputationClient::GetDataSharding(
DataPtr handle) {
if (auto sharded_data = dynamic_cast<PjRtShardedData*>(handle.get())) {
return sharded_data->GetSharding();
}
return std::optional<xla::OpSharding>();
}

std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
absl::Span<const TensorSource> tensors) {
metrics::TimedSection timed(TransferToServerMetric());
@@ -281,7 +304,6 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
tsl::profiler::TraceMeLevel::kInfo);
std::vector<xla::Literal> literals;
literals.reserve(handles.size());

int64_t total_size = 0;
for (auto handle : handles) {
// Use XLA replication to reassemble the sharded data. If input handle
@@ -332,6 +354,11 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
if (instance.is_sharded) {
// TODO(yeounoh) multi-host, multi-slice configurations
compile_options.executable_build_options.set_use_spmd_partitioning(true);
// TODO(yeounoh) this is set to false by default, but explicitly set here
// to expose the knob for future reference. We can override the compiler's
// default behavior to further optimize parameter sharding in the future.
compile_options.executable_build_options
.set_allow_spmd_sharding_propagation_to_output({false});
compile_options.executable_build_options.set_num_partitions(
client_->device_count());
compile_options.executable_build_options.set_num_replicas(1);
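For reference, the SPMD-related build options touched here can be collected into one self-contained sketch. This is illustrative only: the helper name and the locally constructed xla::CompileOptions are assumptions, and device_count stands in for client_->device_count(); the option calls mirror the diff above.

// Sketch of the SPMD compile configuration used by this client.
xla::CompileOptions MakeSpmdCompileOptions(int device_count) {
  xla::CompileOptions compile_options;
  // Run the SPMD partitioner over the single-device program.
  compile_options.executable_build_options.set_use_spmd_partitioning(true);
  // Keep compiler-driven propagation of output shardings off, so the output
  // shardings attached via OutputHandler are preserved as provided.
  compile_options.executable_build_options
      .set_allow_spmd_sharding_propagation_to_output({false});
  // One partition per addressable device, a single replica.
  compile_options.executable_build_options.set_num_partitions(device_count);
  compile_options.executable_build_options.set_num_replicas(1);
  return compile_options;
}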
@@ -487,13 +514,15 @@

std::vector<std::vector<ComputationClient::DataPtr>> data_handles;
data_handles.reserve(results.size());
std::vector<size_t> dims(results.size());
  for (int32_t i = 0; i < results.size(); ++i) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[i]);
XLA_CHECK(pjrt_device->IsAddressable())
<< pjrt_device->DebugString() << " is not addressable.";

std::vector<ComputationClient::DataPtr> datas;
datas.reserve(results[i].size());
dims[i] = results[i].size();
    for (int32_t j = 0; j < results[i].size(); ++j) {
std::unique_ptr<xla::PjRtBuffer> buffer = std::move(results[i][j]);
XLA_CHECK(pjrt_device == buffer->device())
@@ -507,7 +536,8 @@
data_handles.push_back(datas);
}

TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results";
TF_VLOG(1) << "Returning " << data_handles.size() << " sets of results "
<< "with dimensions [" << absl::StrJoin(dims, ",") << "].";
return data_handles;
}

11 changes: 10 additions & 1 deletion third_party/xla_client/pjrt_computation_client.h
@@ -25,6 +25,11 @@ class PjRtComputationClient : public ComputationClient {

std::vector<DataPtr> GetDataShards(DataPtr data) override;

DataPtr WrapDataShards(const std::vector<DataPtr>& shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) override;

std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;

std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) override;

@@ -179,7 +184,11 @@
}

void Assign(const Data& data) override {
XLA_ERROR() << __FUNCTION__ << " not supported.";
Collaborator comment: Nice! We can retry the simple MpDeviceLoader hack for SPMD once this lands, this was the blocker.

const PjRtShardedData& pjrt_sharded_data =
dynamic_cast<const PjRtShardedData&>(data);
if (&pjrt_sharded_data != this) {
shards = std::move(pjrt_sharded_data.shards);
}
}

bool HasValue() const override {
10 changes: 9 additions & 1 deletion third_party/xla_client/xrt_computation_client.h
@@ -248,10 +248,18 @@ class XrtComputationClient : public ComputationClient {
std::vector<xla::util::ExceptionCleanup> LockAsyncDatas(
absl::Span<const DataPtr> datas) override;

std::vector<DataPtr> GetDataShards(DataPtr data) override {
std::vector<DataPtr> GetDataShards(DataPtr data) override { return {data}; }

DataPtr WrapDataShards(const std::vector<DataPtr>& shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override {
// Returns an empty sharding result, since XRT does not support sharding.
return std::optional<xla::OpSharding>();
}

std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) override;

11 changes: 8 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1517,11 +1517,13 @@ void InitXlaModuleBindings(py::module m) {
const py::list& tile_assignment,
bool replicated = false, bool manual = false) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xla::OpSharding sharding =
ShardingUtil::CreateOpSharding(tile_assignment, replicated, manual);
auto new_sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type())));

at::Tensor cpu_tensor;
if (xla::sys_util::GetEnvBool("XLA_USE_SPMD", false) &&
@@ -1560,6 +1562,9 @@ void InitXlaModuleBindings(py::module m) {
std::vector<std::string>{GetVirtualDevice().toString()})[0];
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
8 changes: 7 additions & 1 deletion torch_xla/csrc/ops/device_data.cpp
@@ -11,7 +11,13 @@ namespace torch_xla {
DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data)
: XlaNode(xla_device_data, UnwrapXlaData(data)->shape(), /*num_outputs=*/1,
/*hash_seed=*/(uint32_t)101),
data_(std::move(data)) {}
data_(std::move(data)) {
std::optional<xla::OpSharding> op_sharding =
xla::ComputationClient::Get()->GetDataSharding(UnwrapXlaData(data_));
if (op_sharding.has_value()) {
SetSharding(op_sharding.value());
}
}

std::string DeviceData::ToString() const {
std::stringstream ss;
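With this constructor change, backend data that already carries an OpSharding (for example a PjRtShardedData produced by OutputHandler) re-enters the IR graph with its sharding annotation attached to the DeviceData node. A rough sketch of that round trip follows; the helper is hypothetical, the node is constructed directly for illustration (production code goes through the lazy IR factories), and the include path is assumed.

// Hypothetical check, not part of this PR: a DeviceData node built from a
// sharded backend handle should report sharding immediately.
#include <memory>

#include "torch_xla/csrc/ops/device_data.h"  // assumed include path

namespace torch_xla {

bool DeviceDataKeepsSharding(std::shared_ptr<torch::lazy::BackendData> data) {
  // The constructor queries GetDataSharding() and calls SetSharding() itself.
  auto node = std::make_shared<DeviceData>(std::move(data));
  return static_cast<bool>(node->GetSharding());
}

}  // namespace torch_xla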
19 changes: 6 additions & 13 deletions torch_xla/csrc/tensor.cpp
@@ -118,7 +118,8 @@ XLATensor::XLATensor(torch::lazy::Value ir_value,
if (CurrentIrValue()) {
auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
if (xla_node->GetSharding()) {
ShardingSpec sharding = ShardingSpec{*xla_node->GetSharding()};
ShardingSpec sharding =
ShardingSpec{*xla_node->GetSharding(), xla_node->xla_shape()};
SetShardingSpec(sharding);
}
}
@@ -262,7 +263,8 @@ XLATensor::ShardingSpecPtr XLATensor::sharding_spec() const {
auto* xla_node = dynamic_cast<XlaNode*>(ir_value.node.get());
if (xla_node->GetSharding()) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(
*sharding, ShardingSpec{*xla_node->GetSharding()}));
*sharding,
ShardingSpec{*xla_node->GetSharding(), xla_node->xla_shape()}));
}
}
return sharding;
@@ -316,17 +318,8 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
}

void XLATensor::AssignIrValue(torch::lazy::Value ir_value) const {
if (ir_value) {
std::string debug_str = ir_value->ToString();
auto sharding = dynamic_cast<XlaNode*>(ir_value.node.get())->GetSharding();
if (sharding) {
debug_str = " with sharding " sharding->DebugString();
}
TF_VLOG(5) << "Assign IR value " << debug_str;
} else {
TF_VLOG(5) << "Assign empty IR value";
}

TF_VLOG(5) << "Assign IR value: "
<< (ir_value ? ir_value->ToString() : "empty node");
data()->ir_value = std::move(ir_value);
  data()->generation += 1;
}
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -267,8 +267,12 @@ class XLATensor : public torch::lazy::LazyTensor {
// HloSharding for replication, manual and tile shardings.
struct ShardingSpec {
ShardingSpec(const xla::OpSharding& sharding) : sharding(sharding) {}
ShardingSpec(const xla::OpSharding& sharding, const xla::Shape& shape)
: sharding(sharding), shape(shape) {}

xla::OpSharding sharding;
    // Optional unpartitioned source tensor shape.
std::optional<xla::Shape> shape;
};

// Annotate the IR value with ShardingSpec.
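The optional shape field records the unpartitioned (global) tensor shape next to the OpSharding proto, which is what the mark_sharding binding above now supplies via MakeShapeWithDeviceLayout and what OutputHandler can use to reassemble full-size outputs. A small illustrative sketch of the two constructors follows; the function name, shapes, and tiling are made up, and the usual tensor.h / XLA includes are assumed.

// Illustration only, not part of this PR.
void ShardingSpecShapeExample() {
  xla::Shape global_shape = xla::ShapeUtil::MakeShape(xla::F32, {8});
  xla::OpSharding tiled =
      xla::HloSharding::Tile1D(global_shape, /*num_tiles=*/2).ToProto();

  // Sharding only: the optional global shape stays empty.
  XLATensor::ShardingSpec sharding_only(tiled);
  // Sharding plus the unpartitioned global shape.
  XLATensor::ShardingSpec with_shape(tiled, global_shape);

  XLA_CHECK(!sharding_only.shape.has_value());
  XLA_CHECK(with_shape.shape.has_value());
}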