author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-11 03:36:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-06-11 03:39:13 -0700
commit    73c479056aca52e83f84d7df4132c420f1f3feed (patch)
tree      6c16da114e18344020418a884f931fdaf83ea5a6
parent    3a1d8bd815b5216bc9515801e4d59cf3ebd1126d (diff)
[TuplePointsToAnalysis] Be less conservative on loop fusion nodes when reusing buffer.
Previously, we said we could not reuse an operand's buffer for a loop fusion node if any of the fusion's inputs was a broadcast or reshape. That is too conservative: in theory we can still reuse the operand's buffer as long as all users of that particular operand are elementwise. This CL implements that, allowing operand and output buffers to be shared for partially elementwise fusions.

The same change was recently applied to DataFlowAnalysis, but this pass is still used in many places.

PiperOrigin-RevId: 200028414
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis.cc       | 27
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc  | 25
2 files changed, 41 insertions, 11 deletions
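To make the new condition concrete, here is a minimal, self-contained C++ sketch of the old and new reuse checks. It is not XLA code: Instruction, Opcode, CanReuseOld, and CanReuseNew are hypothetical stand-ins for HloInstruction, HloOpcode, and the logic inside CanShareOperandBufferWithUser, under the assumption that the relevant question is whether every transitive user of the operand inside the fusion body is elementwise.

#include <deque>
#include <set>
#include <vector>

// Hypothetical stand-ins for HloOpcode / HloInstruction (not XLA code).
enum class Opcode { kParameter, kBroadcast, kReshape, kAdd, kMultiply, kTuple };

struct Instruction {
  Opcode opcode;
  std::vector<const Instruction*> operands;

  // Elementwise ops map each output element to the same element of every
  // operand; broadcast, reshape, and tuple do not.
  bool IsElementwise() const {
    return opcode == Opcode::kAdd || opcode == Opcode::kMultiply;
  }
};

// Old behavior: refuse to reuse the operand's buffer if the fusion body
// contains a broadcast or reshape anywhere, even one fed by an unrelated
// operand.
bool CanReuseOld(const std::vector<const Instruction*>& fused_body) {
  for (const Instruction* instr : fused_body) {
    if (instr->opcode == Opcode::kBroadcast ||
        instr->opcode == Opcode::kReshape) {
      return false;
    }
  }
  return true;
}

// New behavior: reuse is allowed iff every transitive user of this
// particular operand inside the fusion body is elementwise; broadcasts or
// reshapes fed by *other* operands no longer disqualify reuse.
bool CanReuseNew(const Instruction& fusion_param,
                 const std::vector<const Instruction*>& fused_body) {
  std::deque<const Instruction*> worklist = {&fusion_param};
  std::set<const Instruction*> visited;
  while (!worklist.empty()) {
    const Instruction* current = worklist.front();
    worklist.pop_front();
    for (const Instruction* instr : fused_body) {
      bool uses_current = false;
      for (const Instruction* op : instr->operands) {
        uses_current |= (op == current);
      }
      if (!uses_current || !visited.insert(instr).second) continue;
      if (!instr->IsElementwise()) return false;
      worklist.push_back(instr);
    }
  }
  return true;
}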
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index bb634e6573..eb6d1ada6b 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -723,15 +723,16 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
return false;
}
if (user->opcode() == HloOpcode::kFusion) {
- if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- user->fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice) {
- // Loop fusion with kDynamicUpdateSlice fused root.
- //
- // Returns true iff there is exactly one use of 'operand' at shape index
- // 'operand_index', and this singleton use is the fused root at operand
- // index 0.
- return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
+ if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ if (user->fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice) {
+ // Loop fusion with kDynamicUpdateSlice fused root.
+ //
+ // Returns true iff there is exactly one use of 'operand' at shape index
+ // 'operand_index', and this singleton use is the fused root at operand
+ // index 0.
+ return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
+ }
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
@@ -789,8 +790,12 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
return param_uses.size() == 1 && param_uses[0].first == callee_root &&
callee_root->IsElementwiseOnOperand(param_uses[0].second);
}
- // Check if 'user' is element-wise.
- return user->IsElementwise();
+ // Loop fusions that contain transposing copies won't reach here as they have
+ // different layouts, which fails the check in the beginning of this function.
+ //
+ // Multi-output fusion will fail the check here as tuples are not considered
+ // an elementwise operation.
+ return user->IsElementwiseOnOperand(user->operand_index(operand));
}
} // namespace xla
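The comments added in the hunk above name two cases that never reach the relaxed per-operand check: loop fusions containing transposing copies fail the layout comparison at the top of the function, and multi-output fusions fail because a tuple root is not an elementwise operation. The second case can be illustrated with the hypothetical toy model from above (again an assumption-laden sketch, not XLA's actual traversal):

#include <cassert>

// Multi-output fusion in the toy model: tuple(add(p0, p1), multiply(p0, p0)).
// The kTuple root is a transitive user of p0 and is not elementwise, so
// buffer reuse is refused for p0 even though add and multiply themselves
// are elementwise.
void MultiOutputFusionIsRejected() {
  Instruction p0{Opcode::kParameter, {}};
  Instruction p1{Opcode::kParameter, {}};
  Instruction add{Opcode::kAdd, {&p0, &p1}};
  Instruction mul{Opcode::kMultiply, {&p0, &p0}};
  Instruction root{Opcode::kTuple, {&add, &mul}};
  std::vector<const Instruction*> body = {&add, &mul, &root};
  assert(!CanReuseNew(p0, body));
}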
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index f558316b05..5734f28407 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1148,5 +1148,30 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
call, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) {
+ Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32});
+ Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16});
+
+ auto builder = HloComputation::Builder(TestName() + "_fusion");
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, full_shape, "full"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, broadcast_shape, "small"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(full_shape, param1, {0}));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ full_shape, HloOpcode::kAdd, param0, broadcast));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {add, broadcast}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {}));
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {}));
+}
+
} // namespace
} // namespace xla
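The new test mirrors the scenario from the commit message: add(param0, broadcast(param1)) fused as a loop fusion, where param0's buffer may be shared but param1's may not. For comparison, the same scenario run against the hypothetical sketch above (a main() completing that toy program, not the XLA test harness):

int main() {
  // Fusion body: add(param0, broadcast(param1)).
  Instruction param0{Opcode::kParameter, {}};
  Instruction param1{Opcode::kParameter, {}};
  Instruction broadcast{Opcode::kBroadcast, {&param1}};
  Instruction add{Opcode::kAdd, {&param0, &broadcast}};
  std::vector<const Instruction*> body = {&broadcast, &add};

  // The old rule rejected both operands because a broadcast appears
  // somewhere in the fusion body.
  assert(!CanReuseOld(body));

  // The new rule distinguishes the operands: param0 is consumed only by the
  // elementwise add, so its buffer may be reused; param1 feeds the
  // broadcast, so it may not.
  assert(CanReuseNew(param0, body));
  assert(!CanReuseNew(param1, body));

  MultiOutputFusionIsRejected();
  return 0;
}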