diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_sharding_metadata.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_sharding_metadata.cc | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 39036e205e..4f91d619ef 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -88,6 +88,12 @@ std::vector<PassThrough> LocatePassThroughDomainLinks( VLOG(2) << " " << instruction->ToString(); } } + if (instruction == instruction->parent()->root_instruction()) { + pass_through.emplace_back(nullptr, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " <root>"; + VLOG(2) << " " << instruction->ToString(); + } } return pass_through; } @@ -101,8 +107,12 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), tuple, 0)); gte->set_sharding(sharding); - TF_RETURN_IF_ERROR( - pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + if (pass_through.user != nullptr) { + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } else { + pass_through.operand->parent()->set_root_instruction(gte); + } } return Status::OK(); } |