
[SPMD] Preserve parameter sharding with output data sharding #4721

Merged: 4 commits into master on Mar 8, 2023

Conversation

@yeounoh (Contributor) commented Mar 3, 2023

This addresses the same problem as #4696 with an alternative solution: we shard the replicated output while handling the computation results. This avoids a post-traversal pass that replaces the original data node with a sharded one, and is therefore more efficient. Key changes include (a rough sketch of the new pieces follows the list):

  • Introduce ShardingUtil::OutputHandler
  • Add an XLAShardingTest.OutputHandler unit test; test_optimizer_step_with_sharding already validates the change with a simple end-to-end example.
  • Add std::optional<xla::Shape> to ShardingSpec
  • Add std::optional<xla::OpSharding> to PjRtShardedData
  • Pass an additional std::vector<XLATensor::ShardingSpecPtr> param to XLAGraphExecutor::ScheduleSyncTensorsGraph, since the async function now calls ShardingUtil::OutputHandler
  • Introduce & call XLAGraphExecutor::CollectShardingSpecs before calling ScheduleSyncTensorsGraph
  • Introduce WrapDataShards and GetDataSharding APIs in ComputationClient.
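
For reference, here is a rough sketch of the shape these pieces could take. All names, members, and signatures below are assumptions for illustration, not the actual torch_xla definitions; `xla::OpSharding`, `xla::Shape`, and `ComputationClient::DataPtr` are the existing XLA/client types.

```cpp
// Illustrative sketch only: hypothetical declarations, not torch_xla's code.
// Idea: a ShardingSpec carries the OpSharding plus the global (unpartitioned)
// shape, and OutputHandler wraps per-device execution results back into
// sharded data handles so parameter sharding is preserved across steps.
#include <memory>
#include <optional>
#include <vector>

struct ShardingSpec {
  xla::OpSharding sharding;
  // Newly added: the global shape, needed when wrapping per-device shards
  // into a single sharded data handle (e.g. PjRtShardedData).
  std::optional<xla::Shape> shape;
};
using ShardingSpecPtr = std::shared_ptr<ShardingSpec>;

class ShardingUtil {
 public:
  // Takes the per-device results of an SPMD execution and, for each output
  // that has a sharding spec, wraps the device shards into one sharded data
  // handle instead of leaving a plain replicated data node. This is what
  // removes the need for a post-traversal replacement pass.
  static std::vector<ComputationClient::DataPtr> OutputHandler(
      std::vector<std::vector<ComputationClient::DataPtr>> sharded_results,
      std::vector<ShardingSpecPtr> sharding_specs);
};
```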

@yeounoh yeounoh requested review from alanwaketan and JackCaoG March 3, 2023 23:52
@yeounoh yeounoh self-assigned this Mar 3, 2023
@yeounoh yeounoh marked this pull request as draft March 3, 2023 23:52
@yeounoh yeounoh force-pushed the new_param_sharding_fix branch from 54c51f1 to 420d701 on March 3, 2023 23:55
@yeounoh yeounoh force-pushed the new_param_sharding_fix branch 15 times, most recently from 3eac5e6 to f26b305 on March 4, 2023 04:02
@yeounoh yeounoh marked this pull request as ready for review March 4, 2023 04:20
@yeounoh yeounoh force-pushed the new_param_sharding_fix branch 2 times, most recently from 5c3e631 to 0ddee73 on March 6, 2023 22:59
torch_xla/csrc/xla_graph_executor.cpp (outdated review thread, resolved)
torch_xla/csrc/xla_sharding_util.cpp (outdated review thread, resolved)
test/cpp/test_xla_sharding.cpp (outdated review thread, resolved)
@JackCaoG (Collaborator) commented Mar 7, 2023

OutputHandler seems to crash

@yeounoh yeounoh force-pushed the new_param_sharding_fix branch from e45ab94 to 5ba829f on March 7, 2023 21:04
@yeounoh yeounoh force-pushed the new_param_sharding_fix branch from 5ba829f to 26279e3 on March 7, 2023 21:04
@yeounoh (Contributor, Author) commented Mar 7, 2023

OutputHandler seems to crash

Yeah, we need at least 2 devices to create an HLO sharding.

2023-03-07 21:01:25.297118: F external/org_tensorflow/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc:54] Check failed: num_tiles > 1 (1 vs. 1)

Added the safeguard.
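
For context, a minimal sketch of the kind of guard this refers to; the function name and exact condition are assumptions for illustration, not the actual change:

```cpp
// Illustrative guard only (hypothetical name): tiled HLO shardings require
// more than one tile, so wrapping outputs into sharded data is skipped for
// replicated shardings or single-device runs, avoiding the CHECK failure
// (num_tiles > 1) shown above.
bool ShouldWrapShardedOutput(const xla::OpSharding& sharding,
                             size_t num_local_devices) {
  return num_local_devices > 1 &&
         sharding.type() != xla::OpSharding::REPLICATED;
}
```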

@yeounoh yeounoh requested review from steventk-g and jonb377 March 7, 2023 22:03
@yeounoh yeounoh changed the title Preserve parameter sharding with output data sharding [SPMD] Preserve parameter sharding with output data sharding Mar 7, 2023
@yeounoh yeounoh force-pushed the new_param_sharding_fix branch from 26279e3 to 8d83ef4 on March 7, 2023 22:55
@JackCaoG (Collaborator) left a comment

Thanks!

@jonb377 (Collaborator) left a comment

LGTM, thanks!

@@ -179,7 +184,11 @@ class PjRtComputationClient : public ComputationClient {
}

void Assign(const Data& data) override {
XLA_ERROR() << __FUNCTION__ << " not supported.";
A collaborator commented on this hunk:

Nice! We can retry the simple MpDeviceLoader hack for SPMD once this lands, this was the blocker.
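
For context, a minimal sketch of what a sharded-data Assign could look like, mirroring the unsharded PjRtData pattern; the member names are assumptions, not the actual implementation:

```cpp
// Illustrative sketch only: hypothetical members (shards, sharding).
// Instead of erroring out as "not supported", Assign copies another
// PjRtShardedData's per-device shard handles and its optional OpSharding.
void Assign(const Data& data) override {
  const PjRtShardedData& pjrt_data =
      dynamic_cast<const PjRtShardedData&>(data);
  if (&pjrt_data != this) {
    shards = pjrt_data.shards;      // per-device shard handles
    sharding = pjrt_data.sharding;  // std::optional<xla::OpSharding>
  }
}
```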

@yeounoh yeounoh merged commit e2abcaf into master Mar 8, 2023
mateuszlewko pushed a commit that referenced this pull request Mar 15, 2023
[SPMD] Persist tensor sharding with XLA sharding propagation
ManfeiBai pushed a commit to ManfeiBai/PyTorchXLA that referenced this pull request Mar 29, 2023
…#4721)

[SPMD] Persist tensor sharding with XLA sharding propagation