diff options
author | 2018-07-09 04:32:39 -0700 | |
---|---|---|
committer | 2018-07-09 04:36:18 -0700 | |
commit | 75a80aa3aa32fa12b74387b67f3d73aca532fc89 (patch) | |
tree | 6136b53f0a9850838b6540429e4041b4e4c78cef | |
parent | 0063183a62f69c2523a3982c70d72e231428fb60 (diff) |
Fix domain removal when the root instruction is an empty domain
If a domain becomes empty because the various optimizations removed all
instructions from it, then we have to re-add some instructions to make sure
the user-supplied sharding is still respected.
This is especially important for the root instruction as the user will
expect the data to be available on the device they requested. Before
this CL, we failed to insert the tuple->gte sequence into the empty
domain due to a bug where we only considered cases where we have an exit
domain, which is not the case for the root instruction.
PiperOrigin-RevId: 203744534
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_domain_map.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_domain_test.cc | 38 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_metadata.cc | 14 |
3 files changed, 55 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 957024a64a..9e096320db 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -62,6 +62,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } + if (instruction == instruction->parent()->root_instruction()) { + auto domain = MakeUnique<DomainMetadata::Domain>(); + domain->enter_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 3859e4cae6..00b2c860a7 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -436,6 +436,44 @@ ENTRY entry { HloSharding::AssignDevice(0)})); } +TEST_F(HloDomainTest, EmptyRootDomain) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + %param = f32[1] parameter(0), sharding={maximal device=0} + %tuple = (f32[1]) tuple(%param), + sharding={maximal device=1} + ROOT %gte = f32[1] get-tuple-element(%tuple), index=0, + sharding={maximal device=1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "tuple", "param")); + EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple")); + + // Remove %tuple and %gte (tuple simplification) + HloInstruction* gte = FindInstruction(module, "gte"); + HloInstruction* tuple = FindInstruction(module, "tuple"); + module->entry_computation()->set_root_instruction(tuple->mutable_operand(0)); + TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte)); + 
TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple)); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(root->has_sharding()); + EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1)); +} + // Tests that text dumps of domain instructions can be parsed back, in the // specific case of null shardings. TEST_F(HloDomainTest, DumpParseNullSharding) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 39036e205e..4f91d619ef 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -88,6 +88,12 @@ std::vector<PassThrough> LocatePassThroughDomainLinks( VLOG(2) << " " << instruction->ToString(); } } + if (instruction == instruction->parent()->root_instruction()) { + pass_through.emplace_back(nullptr, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " <root>"; + VLOG(2) << " " << instruction->ToString(); + } } return pass_through; } @@ -101,8 +107,12 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), tuple, 0)); gte->set_sharding(sharding); - TF_RETURN_IF_ERROR( - pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + if (pass_through.user != nullptr) { + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } else { + pass_through.operand->parent()->set_root_instruction(gte); + } } return Status::OK(); } |