aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-26 16:20:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 16:24:30 -0700
commit111745bdf9338926626d3aeec6736c75f55c608a (patch)
treef9b2648c164976e672aacf6ffc7e14930b859a58 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parentfad8d28c8afb5bbedabb91110b07fc130a9ca36e (diff)
[TF:XLA] Align the two implementations of CanShareOperandBufferWithUser.
Eventually (when TuplePointsToAnalysis is removed), there will be only one implementation left. Also, use early return instead of else-if to make the code less indented. PiperOrigin-RevId: 206240067
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc39
1 files changed, 21 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index de1a32d8bd..1abfcb7703 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -1017,19 +1017,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}
if (user->opcode() == HloOpcode::kFusion) {
+ if (fusion_can_share_buffer_ != nullptr) {
+ return fusion_can_share_buffer_(user, operand);
+ }
// Get the parameter associated with 'operand';
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
const HloValue& value = GetValueDefinedAt(fusion_param, operand_index);
- if (value.uses().size() != 1) {
- if (MultiDynamicSliceUseShareSameIndices(value.uses())) {
- return true;
- }
- return false;
+ if (MultiDynamicSliceUseShareSameIndices(value.uses())) {
+ return true;
}
- const HloUse& use = value.uses()[0];
-
if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
user->fusion_kind() == HloInstruction::FusionKind::kInput) {
if (user->fused_expression_root()->opcode() ==
@@ -1039,13 +1037,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// Returns true iff there is exactly one use of 'operand' at shape index
// 'operand_index', and this singleton use is the fused root at operand
// index 0.
- return use.instruction == user->fused_expression_root() &&
- use.operand_number == 0;
- } else {
- return AreTransitiveUsesElementwiseOrTuple(fusion_param);
+ if (value.uses().size() == 1) {
+ const HloUse& use = value.uses()[0];
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == 0;
+ }
+ return false;
}
- } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
- user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
+ return AreTransitiveUsesElementwiseOrTuple(fusion_param);
+ }
+ if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
+ user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
// Check if one operand of kAdd fused root is kDot or kConvolution.
@@ -1066,11 +1068,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// Returns true iff there is exactly one use of 'operand' at shape index
// 'operand_index', and this singleton use is the fused root (at operand
// index 'other_add_operand_index').
- return use.instruction == user->fused_expression_root() &&
- use.operand_number == other_add_operand_index;
- } else if (fusion_can_share_buffer_ != nullptr &&
- fusion_can_share_buffer_(user, operand)) {
- return true;
+ if (value.uses().size() == 1) {
+ const HloUse& use = value.uses()[0];
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == other_add_operand_index;
+ }
+ return false;
}
}