aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding_metadata.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc70
1 files changed, 38 insertions, 32 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 39036e205e..94f5a3b273 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -88,6 +88,12 @@ std::vector<PassThrough> LocatePassThroughDomainLinks(
VLOG(2) << " " << instruction->ToString();
}
}
+ if (instruction == instruction->parent()->root_instruction()) {
+ pass_through.emplace_back(nullptr, instruction);
+ VLOG(2) << "Found passthrough domain link:";
+ VLOG(2) << " <root>";
+ VLOG(2) << " " << instruction->ToString();
+ }
}
return pass_through;
}
@@ -101,8 +107,12 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
HloInstruction::CreateGetTupleElement(pass_through.operand->shape(),
tuple, 0));
gte->set_sharding(sharding);
- TF_RETURN_IF_ERROR(
- pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ if (pass_through.user != nullptr) {
+ TF_RETURN_IF_ERROR(
+ pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ } else {
+ pass_through.operand->parent()->set_root_instruction(gte);
+ }
}
return Status::OK();
}
@@ -235,21 +245,6 @@ StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
const HloSharding& sharding) {
- // Here is the place to call external sharding normalizers, which are
- // implemented in other modules (ie, spatial partitioning).
- // The signature of the external normalizer function should be something
- // like:
- //
- // StatusOr<bool> Normalizer(const DomainMetadata::Domain&,
- // const HloSharding& sharding);
- //
- // The function should return true if it has processed the domain
- // normalization, false if domain was not one recognized by it, or an error.
- // We will call the functions in order below, and fall back to local code if
- // none of the external normalizers acted on the domain.
- // External normalizers should not handle the cases that are already handled
- // locally.
-
// None of the external normalizers handled the domain sharding, try to see
// whether this is a single sharding first.
auto single_sharding = sharding.ExtractSingleSharding();
@@ -380,25 +375,36 @@ string ShardingMetadata::ToString() const {
return sharding_ != nullptr ? sharding_->ToString() : "{}";
}
-Status ShardingMetadata::NormalizeInstructions(
- const DomainMetadata::Domain& domain) const {
- if (sharding_ != nullptr) {
- VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":";
- TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_));
- TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_));
+/*static*/ StatusOr<const ShardingMetadata*>
+ShardingMetadata::ToShardingMetadata(const DomainMetadata* metadata) {
+ if (metadata->Kind() != ShardingMetadata::KindName()) {
+ return Status(
+ tensorflow::error::INVALID_ARGUMENT,
+ "ShardingMetadata normalizer called with incorrect domain metadata");
}
- return Status::OK();
+ return static_cast<const ShardingMetadata*>(metadata);
}
-Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> sharding,
- ExtractOriginalCommonSharding(domain.instructions));
- if (sharding != nullptr) {
- VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString()
- << ":";
- TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
+Status ShardingMetadata::NormalizeShardingDomain(
+ const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
+ if (metadata != nullptr) {
+ TF_ASSIGN_OR_RETURN(const auto& sharding_metadata,
+ ToShardingMetadata(metadata));
+ const HloSharding* sharding = sharding_metadata->sharding();
+ if (sharding != nullptr) {
+ VLOG(4) << "Normalizing sharding to " << sharding->ToString() << ":";
+ TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
+ TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding));
+ }
} else {
- VLOG(1) << "Unable to find common sharding";
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> sharding,
+ ExtractOriginalCommonSharding(domain.instructions));
+ if (sharding != nullptr) {
+ VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString();
+ TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
+ } else {
+ VLOG(1) << "Unable to find common sharding";
+ }
}
return Status::OK();
}