aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/copy_insertion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion.cc')
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc86
1 files changed, 38 insertions, 48 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index ab3d846403..36fb9b43aa 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -76,15 +76,6 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
policy.copy_parameters_and_constants = true;
policy.copy_root_replicated_buffers = true;
}
- for (const CallSite& site : node.caller_callsites()) {
- // The AddCopiesForConditional() already adds copies, but the copy remover
- // removes them, so we re-add them by returning the policy here. But really
- // the copy remover should not be removing them.
- if (site.instruction()->opcode() == HloOpcode::kConditional) {
- policy.copy_parameters_and_constants = true;
- policy.copy_root_replicated_buffers = true;
- }
- }
return policy;
}
@@ -360,26 +351,6 @@ Status StripControlDependenciesFrom(HloInstruction* instruction) {
return Status::OK();
}
-// Add kCopy instructions to the given module to guarantee there is no
-// live-range interference. Generally interference can only occur around kWhile
-// instructions which have update-in-place semantics.
-Status AddCopiesToResolveInterference(HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
-
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kWhile) {
- TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
- } else if (instruction->opcode() == HloOpcode::kConditional) {
- TF_RETURN_IF_ERROR(
- AddCopiesForConditional(*alias_analysis, instruction));
- }
- }
- }
- return Status::OK();
-}
-
// Class for removing unnecessary copies from the module.
//
// kCopy instructions are added conservatively to guarantee no live range
@@ -954,6 +925,36 @@ class CopyRemover {
BufferValueTracker buffer_value_tracker_;
};
+void MaybeDumpModule(const string& message, const HloModule& module) {
+ if (VLOG_IS_ON(3)) {
+ VLOG(3) << message;
+ XLA_VLOG_LINES(3, module.ToString());
+ hlo_graph_dumper::MaybeDumpHloModule(module, message);
+ }
+}
+
+} // namespace
+
+// Add kCopy instructions to the given module to guarantee there is no
+// live-range interference. Generally interference can only occur around kWhile
+// instructions which have update-in-place semantics.
+Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
+ HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
+
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kWhile) {
+ TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
+ } else if (instruction->opcode() == HloOpcode::kConditional) {
+ TF_RETURN_IF_ERROR(
+ AddCopiesForConditional(*alias_analysis, instruction));
+ }
+ }
+ }
+ return Status::OK();
+}
+
// Add copies to address special constraints on the roots of computations not
// related to live range interference:
//
@@ -964,9 +965,10 @@ class CopyRemover {
//
// (3) Constants and parameters cannot be live out of the entry computation
//
-Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
+Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
+ HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
+ HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
// Identify which shape indices of which instructions need to be copied. Store
// these results in 'instructions_to_copy'.
@@ -1074,32 +1076,20 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
return Status::OK();
}
-Status VerifyNoLiveRangeInterference(HloModule* module) {
+Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
+ HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
DependencyHloOrdering ordering(module);
TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
return Status::OK();
}
-void MaybeDumpModule(const string& message, const HloModule& module) {
- if (VLOG_IS_ON(3)) {
- VLOG(3) << message;
- XLA_VLOG_LINES(3, module.ToString());
- hlo_graph_dumper::MaybeDumpHloModule(module, message);
- }
-}
-
-} // namespace
-
-Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering, HloModule* module,
- const HloDataflowAnalysis::FusionCanShareBufferFunction&
- fusion_can_share_buffer) {
+Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
+ HloModule* module) {
MaybeDumpModule("after adding copies to resolve interference", *module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module, fusion_can_share_buffer));
+ HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
CopyRemover copy_remover(*alias_analysis, ordering, module);
XLA_VLOG_LINES(3, copy_remover.ToString());