aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-06-07 18:42:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 18:45:05 -0700
commita9ddfe50eee83b2f18293241ab96f0a1e2b4b05b (patch)
tree849afa106cb882d918cfa0a633134bb89f7f014f /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parent2f41346cbc0c8ecb915983a1f8711fd0d0ccc50e (diff)
[DataFlowAnalysis] Be less conservative on loop fusion nodes when reusing buffer.
- Previously, we say we cannot reuse operand buffer for a loop fusion node if any of the fusion's inputs is a broadcast or reshape. That's too conservative since in theory we can still reuse the operand's buffer if all the users of that particular operand are elementwise. This CL implements that. - Also fixed a bug in previous code where a dynamic update fusion node that ends with convert (added for bf16) is not caught by the if condition currectly. PiperOrigin-RevId: 199731488
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc123
1 files changed, 123 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 5798326dcb..db1822ec47 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1974,6 +1974,89 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest,
+ NonElementwiseLoopFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "param0"));
+
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0));
+
+ auto reverse = builder.AddInstruction(
+ HloInstruction::CreateReverse(data_shape, neg, {0, 1}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {reverse, neg}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+ MultiOutputFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ Shape in_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, in_shape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, in_shape, "param1"));
+
+ auto copy0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
+ auto copy1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {0}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {1}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {0}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {1}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+ ElementwiseLoopFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape, one, {1}));
+
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));
+
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {exp, neg}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
+ fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
auto builder = HloComputation::Builder(TestName());
@@ -2048,6 +2131,46 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
fusion, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest,
+ FusedDynamicUpdateSliceWithConvertCantShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ auto convert1 = builder.AddInstruction(
+ HloInstruction::CreateConvert(data_shape_bf16, gte1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape_bf16, convert1, update, starts));
+
+ auto convert2 = builder.AddInstruction(
+ HloInstruction::CreateConvert(data_shape, dynamic_update_slice));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {convert2, dynamic_update_slice, starts, update, convert1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction can't share with tuple element 1.
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
auto builder = HloComputation::Builder(TestName());