aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-21 17:31:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 17:34:27 -0700
commita350f66ed250c3dee43cc27b0778c3759f07e810 (patch)
tree6416b984ae6ffb1147937ce0a61f376176680c01 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parentc7776b996d88c83e0e94aa0fde0f32c4fb23144b (diff)
Add backend specific lambda to decide whether a fusion instruction can share buffer with its operand.
PiperOrigin-RevId: 201615582
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc18
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 08a705b18d..f529c0dad7 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -81,12 +81,14 @@ bool MultiDynamicSliceUseShareSameIndices(
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
-HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
- bool bitcast_defines_value)
+HloDataflowAnalysis::HloDataflowAnalysis(
+ const HloModule& module, bool ssa_form, bool bitcast_defines_value,
+ const FusionCanShareBufferFunction& fusion_can_share_buffer)
: module_(module),
ssa_form_(ssa_form),
bitcast_defines_value_(bitcast_defines_value),
- call_graph_(CallGraph::Build(&module)) {}
+ call_graph_(CallGraph::Build(&module)),
+ fusion_can_share_buffer_(fusion_can_share_buffer) {}
bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
const HloInstruction* inst) {
@@ -855,12 +857,13 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
- const HloModule& module, bool ssa_form, bool bitcast_defines_value) {
+ const HloModule& module, bool ssa_form, bool bitcast_defines_value,
+ const FusionCanShareBufferFunction& fusion_can_share_buffer) {
VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
- auto dataflow_analysis = WrapUnique(
- new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
+ auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis(
+ module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
dataflow_analysis->Propagate();
@@ -1041,6 +1044,9 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// index 'other_add_operand_index').
return use.instruction == user->fused_expression_root() &&
use.operand_number == other_add_operand_index;
+ } else if (fusion_can_share_buffer_ != nullptr &&
+ fusion_can_share_buffer_(user, operand)) {
+ return true;
}
}