aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-09 04:32:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-09 04:36:18 -0700
commit75a80aa3aa32fa12b74387b67f3d73aca532fc89 (patch)
tree6136b53f0a9850838b6540429e4041b4e4c78cef
parent0063183a62f69c2523a3982c70d72e231428fb60 (diff)
Fix domain removal when the root instruction is an empty domain
If a domain become empty because the various optimizations removed all instruction from it then we have to re-add some instruction to make sure the user supplied sharding is still respected. This is especially important for the root instruction as the user will expect the data to be available on the device they requested it. Before this CL we failed to insert the tuple->gte sequence into the empty domain due to a bug where we only considered cases where we have an exit domain what is not the case for the root instruction. PiperOrigin-RevId: 203744534
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc38
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc14
3 files changed, 55 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 957024a64a..9e096320db 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -62,6 +62,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
}
+ if (instruction == instruction->parent()->root_instruction()) {
+ auto domain = MakeUnique<DomainMetadata::Domain>();
+ domain->enter_domains.insert(instruction);
+ TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 3859e4cae6..00b2c860a7 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -436,6 +436,44 @@ ENTRY entry {
HloSharding::AssignDevice(0)}));
}
+TEST_F(HloDomainTest, EmptyRootDomain) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ %param = f32[1] parameter(0), sharding={maximal device=0}
+ %tuple = (f32[1]) tuple(%param),
+ sharding={maximal device=1}
+ ROOT %gte = f32[1] get-tuple-element(%tuple), index=0,
+ sharding={maximal device=1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
+
+ HloDomainIsolator isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
+ EXPECT_TRUE(isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module, "tuple", "param"));
+ EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple"));
+
+ // Remove %tuple and %gte (tuple simplification)
+ HloInstruction* gte = FindInstruction(module, "gte");
+ HloInstruction* tuple = FindInstruction(module, "tuple");
+ module->entry_computation()->set_root_instruction(tuple->mutable_operand(0));
+ TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte));
+ TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple));
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
+ EXPECT_TRUE(remover_changed);
+
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(root->has_sharding());
+ EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1));
+}
+
// Tests that text dumps of domain instructions can be parsed back, in the
// specific case of null shardings.
TEST_F(HloDomainTest, DumpParseNullSharding) {
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 39036e205e..4f91d619ef 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -88,6 +88,12 @@ std::vector<PassThrough> LocatePassThroughDomainLinks(
VLOG(2) << " " << instruction->ToString();
}
}
+ if (instruction == instruction->parent()->root_instruction()) {
+ pass_through.emplace_back(nullptr, instruction);
+ VLOG(2) << "Found passthrough domain link:";
+ VLOG(2) << " <root>";
+ VLOG(2) << " " << instruction->ToString();
+ }
}
return pass_through;
}
@@ -101,8 +107,12 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
HloInstruction::CreateGetTupleElement(pass_through.operand->shape(),
tuple, 0));
gte->set_sharding(sharding);
- TF_RETURN_IF_ERROR(
- pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ if (pass_through.user != nullptr) {
+ TF_RETURN_IF_ERROR(
+ pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ } else {
+ pass_through.operand->parent()->set_root_instruction(gte);
+ }
}
return Status::OK();
}