Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
yeounoh committed Mar 7, 2023
1 parent 048c549 commit 5ba829f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
8 changes: 6 additions & 2 deletions test/cpp/test_xla_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,12 @@ TEST_F(XLAShardingTest, InputHandler) {
}

TEST_F(XLAShardingTest, OutputHandler) {
if (xla::sys_util::GetEnvString(xla::env::kEnvPjRtDevice, "") == "") {
GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
if ((xla::sys_util::GetEnvString(xla::env::kEnvPjRtDevice, "") == "") ||
(xla::ComputationClient::Get()->GetLocalDevices().size() < 2)) {
GTEST_SKIP()
<< "`PJRT_DEVICE` is not set, with more than 2 local devices, ("
<< xla::ComputationClient::Get()->GetLocalDevices().size()
<< " local devices detected).";
}

std::vector<std::string> devices =
Expand Down
24 changes: 12 additions & 12 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,18 +848,6 @@ std::vector<torch::lazy::BackendDataPtr> XLAGraphExecutor::SetTensorData(
tensor->data()->view = nullptr;
tensor->data()->tensor_data = c10::nullopt;
}
// Create sharded data placeholder, this will be used to
// hold the corresponding computation results.
if (tensor->sharding_spec()) {
auto sharding = tensor->sharding_spec();
if (!sharding->shape.has_value()) {
sharding->shape = tensor->shape();
}
handle = WrapXlaData(xla::ComputationClient::Get()->WrapDataShards(
{UnwrapXlaData(handle)}, GetVirtualDevice().toString(),
sharding->shape.value(), sharding->sharding));
tensor->data()->handle = handle;
}
tensors_data.emplace_back(std::move(handle));
}
return tensors_data;
Expand All @@ -885,6 +873,18 @@ void XLAGraphExecutor::ExtractIRAndPrepareXlaData_(
torch::lazy::BackendDataPtr handle =
WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
tensor_device.toString(), std::move(shape)));
// Create sharded data placeholder, this will be used to
// hold the corresponding computation results.
if (tensor->sharding_spec()) {
auto sharding = tensor->sharding_spec();
if (!sharding->shape.has_value()) {
sharding->shape = tensor->shape();
}
handle = WrapXlaData(xla::ComputationClient::Get()->WrapDataShards(
{UnwrapXlaData(handle)}, GetVirtualDevice().toString(),
sharding->shape.value(), sharding->sharding));
tensor->data()->handle = handle;
}
tensor_data_vec.push_back(handle);
if (tensor->CurrentDataHandle() == nullptr && config.force_ltc_data) {
tensor->AssignIrValue(torch::lazy::Value());
Expand Down

0 comments on commit 5ba829f

Please sign in to comment.