
Commit

Add tests
yeounoh committed Mar 6, 2023
1 parent 275f312 commit 5c3e631
Showing 12 changed files with 124 additions and 104 deletions.
50 changes: 50 additions & 0 deletions test/cpp/test_xla_sharding.cpp
@@ -239,5 +239,55 @@ TEST_F(XLAShardingTest, InputHandler) {
at::kFloat));
}

TEST_F(XLAShardingTest, OutputHandler) {
if (xla::sys_util::GetEnvString(xla::env::kEnvPjRtDevice, "") == "") {
GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
}

std::vector<std::string> devices =
xla::ComputationClient::Get()->GetLocalDevices();

// Prepare an input vector `outputs` with 2 arguments per device.
std::vector<std::vector<xla::ComputationClient::DataPtr>> outputs;
outputs.reserve(devices.size());
at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
for (auto device : devices) {
outputs.push_back(
UnwrapXlaData(CreateTensorsData({tensor, tensor}, {device, device})));
}

xla::Shape tensor_shape =
CreateComputationShapeFromTensor(tensor, GetDefaultDevice());
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
xla::HloSharding::Tile1D(
CreateComputationShapeFromTensor(tensor, GetDefaultDevice()),
devices.size())
.ToProto(),
tensor_shape);
std::vector<XLATensor::ShardingSpecPtr> sharding_specs{sharding_spec,
sharding_spec};

// Shard a PjRtData into a PjRtShardedData.
std::vector<xla::ComputationClient::DataPtr> sharded_outputs =
ShardingUtil::OutputHandler(outputs, sharding_specs,
/*replicated_output=*/true);
EXPECT_EQ(sharded_outputs.size(), 2);
auto shards =
xla::ComputationClient::Get()->GetDataShards(sharded_outputs[0]);
EXPECT_EQ(shards.size(), devices.size());
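// With replicated_output=true, the full replicated result is re-sharded per the
// sharding spec, so each shard holds a tile whose shape differs from the full
// tensor shape.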
EXPECT_TRUE(
!xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape));

// Wrap sharded data into a PjRtShardedData with `devices.size()` shards.
std::vector<xla::ComputationClient::DataPtr> wrapped_outputs =
ShardingUtil::OutputHandler(outputs, sharding_specs,
/*replicated_output=*/false);
EXPECT_EQ(wrapped_outputs.size(), 2);
shards = xla::ComputationClient::Get()->GetDataShards(wrapped_outputs[0]);
EXPECT_EQ(shards.size(), devices.size());
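// With replicated_output=false, the per-device outputs are wrapped as-is, so
// each shard keeps the full tensor shape.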
EXPECT_TRUE(
xla::Shape::Equal().IgnoreLayout()(shards[0]->shape(), tensor_shape));
}

} // namespace cpp_test
} // namespace torch_xla
9 changes: 6 additions & 3 deletions test/spmd/test_xla_sharding.py
@@ -111,6 +111,7 @@ def test_mark_step_with_sharding(self):
self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))

def test_optimizer_step_with_sharding(self):
# Use a simple linear model to test model parameter sharding.
model = self.SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, self._get_mesh((1, self.n_devices)),
(0, 1))
@@ -121,15 +122,17 @@ def test_optimizer_step_with_sharding(self):
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(5):
for i in range(3):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))
# Sharding should persist across mark_step calls; also check that the sharded
# computation can run more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))

def test_inplace_add_with_sharding(self):
xt = torch.ones(2, 2).to(xm.xla_device())
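For reference, a minimal, self-contained sketch of the pattern that test_optimizer_step_with_sharding above exercises: annotate a model parameter with xs.mark_sharding, run a few optimizer steps with mark_step in between, and verify the sharding annotation persists. This assumes a multi-device PJRT setup; the xs.Mesh construction mirrors what the test's _get_mesh helper is expected to do, and the small stand-in model and mesh shape are assumptions, not part of this commit.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs  # module exercised by the test file above

device = xm.xla_device()
num_devices = xm.xrt_world_size()

# Assumption: a (1, num_devices) mesh, analogous to the test's
# self._get_mesh((1, self.n_devices)) helper.
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

# Stand-in for the test's SimpleLinear model (hypothetical sizes).
model = nn.Linear(128, 8).to(device)

# Shard the weight along the second mesh axis, as in the test.
xs.mark_sharding(model.weight, mesh, (0, 1))
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.weight)

optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
data = torch.randn(128, 128).to(device)
target = torch.zeros(128, dtype=torch.long).to(device)

for _ in range(3):
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()

# The sharding annotation should survive repeated mark_step calls.
assert sharding_spec == torch_xla._XLAC._get_xla_sharding_spec(model.weight)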
2 changes: 1 addition & 1 deletion third_party/xla_client/pjrt_computation_client.cc
@@ -254,7 +254,7 @@ ComputationClient::DataPtr PjRtComputationClient::ReplicateShardedData(
<< ", shape=" << handle->shape() << ")";
xla::XlaBuilder b("ReplicateShardedData");
xla::Shape shape = sharded_data->shape();
b.SetSharding(sharded_data->GetSharding().value());
b.SetSharding(sharded_data->GetSharding());

// perform a simple identity calculation to reassemble the input as
// replicated output.
4 changes: 2 additions & 2 deletions third_party/xla_client/pjrt_computation_client.h
@@ -202,10 +202,10 @@ class PjRtComputationClient : public ComputationClient {
return true;
}

std::optional<xla::OpSharding> GetSharding() { return sharding; }
xla::OpSharding GetSharding() { return sharding; }

std::vector<std::shared_ptr<PjRtData>> shards;
std::optional<xla::OpSharding> sharding;
xla::OpSharding sharding;
};

struct PjRtComputation : public Computation {
1 change: 0 additions & 1 deletion torch_xla/csrc/aten_xla_bridge.cpp
@@ -66,7 +66,6 @@ XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
if (impl == nullptr) {
return XLATensorPtr();
}
auto t = impl->tensor();
return impl->tensor();
}

9 changes: 5 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -104,9 +104,10 @@ void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) {

c10::optional<at::ScalarType> PromoteIntegralType(
at::ScalarType src_dtype, const c10::optional<at::ScalarType>& opt_dtype) {
return opt_dtype.has_value() ? opt_dtype.value()
: at::isIntegralType(src_dtype, /*includeBool=*/true) ? at::kLong
: opt_dtype;
return opt_dtype.has_value()
? opt_dtype.value()
: at::isIntegralType(src_dtype, /*includeBool=*/true) ? at::kLong
: opt_dtype;
}

bool IsTypeWithLargerRangeThanLong(torch::ScalarType dtype) {
@@ -539,7 +540,7 @@ at::Tensor XLANativeFunctions::_to_copy(
// Case 1: Materialize the tensor.
if (device && device->type() != c10::kXLA) {
XLA_CHECK(device->type() == c10::kCPU)
<< "only cpu device is supported in _to_copy." << std::endl;
<< "only cpu device is supported in _to_copy.";
auto self_tensor = bridge::GetXlaTensor(self);
auto eager_tensor = self_tensor->ToTensor(/*detached=*/true);

77 changes: 36 additions & 41 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -834,14 +834,14 @@ void BuildProfilerSubmodule(py::module* m) {
py::class_<xla::profiler::ProfilerServer,
std::unique_ptr<xla::profiler::ProfilerServer>>
profiler_server_class(profiler, "ProfilerServer");
profiler.def(
"start_server",
[](int port) -> std::unique_ptr<xla::profiler::ProfilerServer> {
auto server = absl::make_unique<xla::profiler::ProfilerServer>();
server->Start(port);
return server;
},
py::arg("port"));
profiler.def("start_server",
[](int port) -> std::unique_ptr<xla::profiler::ProfilerServer> {
auto server =
absl::make_unique<xla::profiler::ProfilerServer>();
server->Start(port);
return server;
},
py::arg("port"));

profiler.def(
"trace",
@@ -1261,14 +1261,13 @@ void InitXlaModuleBindings(py::module m) {
StepMarker(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
m.def(
"_xla_wait_device_ops",
[](const std::vector<std::string>& devices) {
NoGilSection nogil;
XLAGraphExecutor::Get()->WaitDeviceOps(devices);
xla::ComputationClient::Get()->WaitDeviceOps(devices);
},
py::arg("devices"));
m.def("_xla_wait_device_ops",
[](const std::vector<std::string>& devices) {
NoGilSection nogil;
XLAGraphExecutor::Get()->WaitDeviceOps(devices);
xla::ComputationClient::Get()->WaitDeviceOps(devices);
},
py::arg("devices"));
m.def("_xla_counter_names", []() {
auto counter_names = torch::lazy::GetCounterNames();
auto xla_counter_names = xla::metrics::GetCounterNames();
@@ -1333,35 +1332,32 @@ void InitXlaModuleBindings(py::module m) {
torch::lazy::MetricsArena::Get()->ResetMetrics();
xla::metrics::ClearMetrics();
});
m.def(
"_xla_tensors_report",
[](size_t nodes_threshold, const std::string& device) {
return GetLiveTensorsReport(nodes_threshold, device);
},
py::arg("nodes_threshold") = 100, py::arg("device") = "");
m.def("_xla_tensors_report",
[](size_t nodes_threshold, const std::string& device) {
return GetLiveTensorsReport(nodes_threshold, device);
},
py::arg("nodes_threshold") = 100, py::arg("device") = "");
m.def("_xla_memory_info", [](const std::string& device) -> py::object {
return GetMemoryInfo(device);
});
m.def(
"_xla_set_use_full_mat_mul_precision",
[](bool use_full_mat_mul_precision) {
XlaHelpers::set_mat_mul_precision(use_full_mat_mul_precision
? xla::PrecisionConfig::HIGHEST
: xla::PrecisionConfig::DEFAULT);
},
py::arg("use_full_mat_mul_precision") = true);
m.def("_xla_set_use_full_mat_mul_precision",
[](bool use_full_mat_mul_precision) {
XlaHelpers::set_mat_mul_precision(
use_full_mat_mul_precision ? xla::PrecisionConfig::HIGHEST
: xla::PrecisionConfig::DEFAULT);
},
py::arg("use_full_mat_mul_precision") = true);

py::class_<xla::util::RecordReader, std::shared_ptr<xla::util::RecordReader>>(
m, "RecordReader");
m.def(
"_xla_create_tfrecord_reader",
[](const std::string& path, const std::string& compression,
int64_t buffer_size) {
NoGilSection nogil;
return CreateRecordReader(path, compression, buffer_size);
},
py::arg("path"), py::arg("compression") = "",
py::arg("buffer_size") = 16 * 1024 * 1024);
m.def("_xla_create_tfrecord_reader",
[](const std::string& path, const std::string& compression,
int64_t buffer_size) {
NoGilSection nogil;
return CreateRecordReader(path, compression, buffer_size);
},
py::arg("path"), py::arg("compression") = "",
py::arg("buffer_size") = 16 * 1024 * 1024);
m.def(
"_xla_tfrecord_read",
[](const std::shared_ptr<xla::util::RecordReader>& reader) -> py::object {
@@ -1522,7 +1518,6 @@ void InitXlaModuleBindings(py::module m) {
bool replicated = false, bool manual = false) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLAGraphExecutor::Get()->UnregisterTensor(xtensor->data().get());
xla::OpSharding sharding =
ShardingUtil::CreateOpSharding(tile_assignment, replicated, manual);
auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
@@ -1568,7 +1563,7 @@ void InitXlaModuleBindings(py::module m) {
xtensor->SetXlaData(xla_data);
xtensor->SetShardingSpec(*new_sharding_spec);

// Re-register sharded tensor
// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
m.def("_xla_clear_sharding", [](const at::Tensor& input) {
34 changes: 6 additions & 28 deletions torch_xla/csrc/tensor.cpp
@@ -243,10 +243,8 @@ void XLATensor::SetShardingSpec(const ShardingSpec& sharding) {
<< ", must be cleared before applying a new one, "
<< sharding.sharding.DebugString();
}
if (CurrentIrValue() || CurrentDataHandle() || CurrentTensorData()) {
dynamic_cast<XlaNode*>(GetIrValue().node.get())
->SetSharding(sharding_spec()->sharding);
}
dynamic_cast<XlaNode*>(GetIrValue().node.get())
->SetSharding(sharding_spec()->sharding);
}
void XLATensor::ClearShardingSpec() {
data()->sharding = nullptr;
@@ -452,8 +450,7 @@ at::Tensor XLATensor::ToTensor(bool detached) {
at::Tensor tensor;
c10::optional<at::Tensor> tensor_data = CurrentTensorData();
if (!tensor_data) {
if (CurrentIrValue()) XLAGraphExecutor::Get()->DeviceBarrier(GetDevice());

XLAGraphExecutor::Get()->DeviceBarrier(GetDevice());
// The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
// XlaNode is available on the tensor.
std::vector<at::Tensor> tensors = XlaDataToTensors({GetXlaData()}, dtype());
@@ -575,42 +572,23 @@ torch::lazy::Value XLATensor::MaybeCastIrValue(
XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value) const {
ir_value = MaybeCastIrValue(std::move(ir_value), GetDevice(),
/*logical_element_type=*/c10::nullopt);
bool try_propagate_sharding = ir_value && sharding_spec();
auto xtensor = Create(std::move(ir_value), GetDevice(), dtype_optional());
if (try_propagate_sharding) {
// TODO(yeounoh) remove this after functionalization fix is merged.
xtensor->SetShardingSpec(*sharding_spec());
}
return xtensor;
return Create(std::move(ir_value), GetDevice(), dtype_optional());
}

XLATensorPtr XLATensor::CreateFrom(
torch::lazy::Value ir_value,
c10::optional<at::ScalarType> logical_element_type_opt) const {
ir_value = MaybeCastIrValue(std::move(ir_value), GetDevice(),
logical_element_type_opt);
bool try_propagate_sharding = ir_value && sharding_spec();
auto xtensor =
Create(std::move(ir_value), GetDevice(), logical_element_type_opt);
if (try_propagate_sharding) {
// TODO(yeounoh) remove this after functionalization fix is merged.
xtensor->SetShardingSpec(*sharding_spec());
}
return xtensor;
return Create(std::move(ir_value), GetDevice(), logical_element_type_opt);
}

XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value,
const torch::lazy::BackendDevice& device,
at::ScalarType logical_element_type) const {
ir_value =
MaybeCastIrValue(std::move(ir_value), device, logical_element_type);
bool try_propagate_sharding = ir_value && sharding_spec();
auto xtensor = Create(std::move(ir_value), device, logical_element_type);
if (try_propagate_sharding) {
// TODO(yeounoh) remove this after functionalization fix is merged.
xtensor->SetShardingSpec(*sharding_spec());
}
return xtensor;
return Create(std::move(ir_value), device, logical_element_type);
}

void XLATensor::ApplyPendingGraph() {
10 changes: 2 additions & 8 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -848,7 +848,7 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::SetTensorData(
tensor->data()->view = nullptr;
tensor->data()->tensor_data = c10::nullopt;
}
// TODO(yeounoh) we create sharded data placeholder, this will be used to
// Create a sharded data placeholder; this will be used to
// hold the corresponding computation results.
if (tensor->sharding_spec()) {
auto sharding = tensor->sharding_spec();
@@ -949,7 +949,7 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash)
<< " on devices: " << absl::StrJoin(devices, ",");
// TODO(yeounoh) OutputHandler creates sharded data for for sharded
// OutputHandler creates sharded data for sharded
// tensor results. Both sharded and unsharded results should be
// "Assign"ed to the corresponding data placeholders.
std::vector<xla::ComputationClient::DataPtr> outputs =
@@ -1177,11 +1177,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());

TF_VLOG(5) << "Initial program result shape: "
<< program_shape.result().ToString();
const std::vector<xla::Shape>& result_tuple_shapes =
program_shape.result().tuple_shapes();

bool should_wrap_parameter =
(program_shape.parameters_size() >= parameter_wrapping_threadshold) &&
using_pjrt;
@@ -1213,7 +1208,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
<< coll.device << " done!";
TF_VLOG(5) << "Compiled program shape "
<< computations.front()->program_shape().ToString() << std::endl;

TF_VLOG(5)
<< "Graph hash " << torch::lazy::HashToString(coll.hash)
<< " is computation hash "
10 changes: 0 additions & 10 deletions torch_xla/csrc/xla_graph_executor.h
@@ -171,16 +171,6 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const torch::lazy::BackendDevice& device);

private:
struct ShardingShapeData {
ShardingShapeData(xla::Shape full_shape, xla::Shape sharded_shape,
size_t tensor_index)
: full_shape(full_shape),
sharded_shape(sharded_shape),
tensor_index(tensor_index){};
xla::Shape full_shape;
xla::Shape sharded_shape;
size_t tensor_index;
};
// This is just to group results from compile(). Since our computation is
// different, we don't reuse the upstream CompilationResult.
struct CompilationResult {
Expand Down
