diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-21 17:31:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-21 17:34:27 -0700 |
commit | a350f66ed250c3dee43cc27b0778c3759f07e810 (patch) | |
tree | 6416b984ae6ffb1147937ce0a61f376176680c01 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | |
parent | c7776b996d88c83e0e94aa0fde0f32c4fb23144b (diff) |
Add backend specific lambda to decide whether a fusion instruction can share buffer with its operand.
PiperOrigin-RevId: 201615582
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 08a705b18d..f529c0dad7 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -81,12 +81,14 @@ bool MultiDynamicSliceUseShareSameIndices( using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, - bool bitcast_defines_value) +HloDataflowAnalysis::HloDataflowAnalysis( + const HloModule& module, bool ssa_form, bool bitcast_defines_value, + const FusionCanShareBufferFunction& fusion_can_share_buffer) : module_(module), ssa_form_(ssa_form), bitcast_defines_value_(bitcast_defines_value), - call_graph_(CallGraph::Build(&module)) {} + call_graph_(CallGraph::Build(&module)), + fusion_can_share_buffer_(fusion_can_share_buffer) {} bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { @@ -855,12 +857,13 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { /* static */ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run( - const HloModule& module, bool ssa_form, bool bitcast_defines_value) { + const HloModule& module, bool ssa_form, bool bitcast_defines_value, + const FusionCanShareBufferFunction& fusion_can_share_buffer) { VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique( - new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); + auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); dataflow_analysis->Propagate(); @@ -1041,6 +1044,9 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // index 'other_add_operand_index'). return use.instruction == user->fused_expression_root() && use.operand_number == other_add_operand_index; + } else if (fusion_can_share_buffer_ != nullptr && + fusion_can_share_buffer_(user, operand)) { + return true; } } |