diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/copy_insertion.cc | 72 |
1 files changed, 44 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 33d8338809..e0ce2e3555 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -472,6 +472,10 @@ class CopyRemover { // between copies added around aliased operations (kWhile) guarantees // this strict order. for (const HloValue* value_a : buffer.values()) { + if (ShapeUtil::IsToken(value_a->shape())) { + // Token values have no representation and cannot interfere. + continue; + } for (const HloValue* value_b : buffer.values()) { if (value_a != value_b) { DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, @@ -613,7 +617,10 @@ class CopyRemover { VLOG(2) << copy->name() << " is not removable"; return false; } - + if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { + VLOG(2) << copy->name() << " is not removable (shape mismatch)"; + return false; + } const CopyNodes& copy_node = copy_map_.at(copy); ValueNode* src = copy_node.src; ValueNode* dest = copy_node.dest; @@ -947,28 +954,6 @@ class CopyRemover { BufferValueTracker buffer_value_tracker_; }; -// Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. -Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); - - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id())) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); - } - } - } - return Status::OK(); -} - // Add copies to address special constraints on the roots of computations not // related to live range interference: // @@ -1065,13 +1050,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { HloInstruction* instruction = pair.first; const ShapeTree<bool>& indices_to_copy = pair.second; + ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape()); std::vector<HloInstruction*> users = instruction->users(); TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, instruction->parent()->DeepCopyInstruction( - instruction, &indices_to_copy)); + instruction, &indices_to_copy, &copies_added)); for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } + // Special case copies are not eligible for later copy elision passes. + indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { + if (has_copy) { + HloInstruction* copy = *copies_added.mutable_element(index); + if (copy != nullptr) { + copy->SetCopyElisionAllowed(false); + } + } + }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1097,6 +1092,31 @@ void MaybeDumpModule(const string& message, const HloModule& module) { } // namespace +Status RemoveUnnecessaryCopies( + const HloOrdering& ordering, + const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) { + MaybeDumpModule("after adding copies to resolve interference", *module); + + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, + HloAliasAnalysis::Run(module)); + CopyRemover copy_remover(*alias_analysis, ordering, module); + XLA_VLOG_LINES(3, copy_remover.ToString()); + + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + !ContainsKey(copies_to_exclude, instruction->unique_id()) && + instruction->CopyElisionAllowed()) { + TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + } + } + } + MaybeDumpModule("after removing unnecessary copies", *module); + + return Status::OK(); +} + StatusOr<bool> CopyInsertion::Run(HloModule* module) { // Copy insertion is performed in three steps: // @@ -1158,14 +1178,10 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) { TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); - MaybeDumpModule("after adding copies to resolve interference", *module); - DependencyHloOrdering ordering(module); TF_RETURN_IF_ERROR( RemoveUnnecessaryCopies(ordering, existing_copies, module)); - MaybeDumpModule("after removing unnecessary copies", *module); - TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); MaybeDumpModule("after adding special-case copies", *module); |