[SPMD] Preserve parameter sharding with output data sharding #4721
Conversation
Yeah, we need at least two devices to create an HLO sharding. Added the safeguard.
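For illustration, a minimal user-level sketch of that kind of guard (the helper name is hypothetical and this is not the PR's actual safeguard, which lives in the C++ sharding code):

```python
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

def maybe_mark_sharding(tensor, mesh, partition_spec):
    # Hypothetical helper: only attach an XLA sharding annotation when
    # there are at least two devices; an HLO sharding over a single
    # device is meaningless, so leave the tensor as-is in that case.
    if len(xm.get_xla_supported_devices()) < 2:
        return tensor
    return xs.mark_sharding(tensor, mesh, partition_spec)
```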
Thanks!
LGTM, thanks!
@@ -179,7 +184,11 @@ class PjRtComputationClient : public ComputationClient {
   }

   void Assign(const Data& data) override {
     XLA_ERROR() << __FUNCTION__ << " not supported.";
Nice! We can retry the simple MpDeviceLoader hack for SPMD once this lands; this was the blocker.
[SPMD] Persist tensor sharding with XLA sharding propagation (#4721)
This addresses the same problem as #4696 with an alternative solution: we shard the replicated output while handling the computation results. This avoids a post-traversal pass that replaces the original data node with a sharded one, and is therefore more efficient. Key changes include:
- New `ShardingUtil::OutputHandler`.
- New `XLAShardingTest.OutputHandler` test for unit testing; `test_optimizer_step_with_sharding` already checks the validity of the change with a simple e2e example.
- Added `std::optional<xla::Shape>` to `ShardingSpec`.
- Added `std::optional<xla::OpSharding>` to `PjRtShardedData`.
- Added a `std::vector<XLATensor::ShardingSpecPtr>` param to `XLAGraphExecutor::ScheduleSyncTensorsGraph`, since the async function now calls `ShardingUtil::OutputHandler`.
- `XLAGraphExecutor::CollectShardingSpecs` is called before `ScheduleSyncTensorsGraph`.
- New `WrapDataShards` and `GetDataSharding` APIs in `ComputationClient`.
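As a rough end-to-end illustration of the behavior this preserves, in the spirit of `test_optimizer_step_with_sharding` (a sketch, not the PR's test code; it assumes an SPMD-enabled runtime and the `torch_xla.experimental.xla_sharding` API of this era, and uses the private `_get_xla_sharding_spec` accessor to inspect the result):

```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

device = xm.xla_device()
num_devices = len(xm.get_xla_supported_devices())

# 1D device mesh; shard the linear weight's first dimension across it.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,))
model = nn.Linear(128, 64).to(device)
xs.mark_sharding(model.weight, mesh, (0, None))

optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
loss = model(torch.randn(16, 128, device=device)).sum()
loss.backward()
optimizer.step()
xm.mark_step()  # executes the graph; outputs are handled by OutputHandler

# With this change, the parameter should keep its sharding spec after the
# in-place optimizer update instead of coming back as replicated data.
print(torch_xla._XLAC._get_xla_sharding_spec(model.weight))
```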