aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar Jeremy Lau <lauj@google.com>2018-05-16 15:54:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-16 15:59:29 -0700
commit250415665dbd6ea200e8fc17e1c61eaf32312343 (patch)
treea1655b82ae0d54808c3c9a21e8838a1cbe602cc5 /tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
parent9fd3485db92d6bfee928dfaaba3dc69938bab8b6 (diff)
Move DoesNotUseOperandBuffer and CanShareOperandBufferWithUser from
liveness_util to methods on TuplePointsToAnalysis and HloDataflowAnalysis. PiperOrigin-RevId: 196903216
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc341
1 files changed, 341 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 07f69b8e13..5798326dcb 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1873,5 +1873,346 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));
+class HloDataflowAnalysisTestBase : public HloTestBase {
+ protected:
+ void BuildModule(std::unique_ptr<HloComputation> computation) {
+ module_ = CreateNewModule();
+ computation_ = module_->AddEntryComputation(std::move(computation));
+ }
+
+ void RunAnalysis() {
+ CHECK_NOTNULL(module_.get());
+ dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie();
+ }
+
+ void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
+ BuildModule(std::move(computation));
+ RunAnalysis();
+ }
+
+ std::unique_ptr<HloModule> module_;
+ HloComputation* computation_ = nullptr;
+ std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
+};
+
+class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {};
+
+TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ // GetTupleElement instructions only access the top-level buffer of their
+ // operand.
+ EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0));
+ EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1));
+ EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0));
+ EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1));
+}
+
+TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, gte1, update, starts));
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {dynamic_update_slice, starts, update, gte1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction never uses tuple element 0, but does use element 1.
+ EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
+ EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
+}
+
+class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
+
+TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape shape = ShapeUtil::MakeShape(F32, {8});
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
+ auto log = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {}));
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape in_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, in_shape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, in_shape, "param1"));
+ auto result = builder.AddInstruction(
+ HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ result, {}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ result, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape shape = ShapeUtil::MakeShape(F32, {8});
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
+ auto copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {}));
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, gte1, update, starts));
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {dynamic_update_slice, starts, update, gte1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction can share with tuple element 1.
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0},
+ fusion, {}));
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1},
+ fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape update_shape = ShapeUtil::MakeShape(F32, {4});
+ Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "data"));
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, update_shape, "update"));
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, starts_shape, "starts"));
+ auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape, data, update, starts));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ // The DynamicUpdateSlice instruction can share with the data operand, but not
+ // with update or starts.
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {}));
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {}));
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto a = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+ auto b = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto dot = builder.AddInstruction(
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto add_operand = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape, one, {1}));
+
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape, HloOpcode::kAdd, dot, add_operand));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {add, dot}, HloInstruction::FusionKind::kOutput);
+ RunAnalysis();
+
+ // Output fused dot add should be able to share buffer with 'add_operand'.
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(add_operand, {},
+ fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape, one, {1}));
+
+ auto reverse = builder.AddInstruction(
+ HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
+
+ auto two = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {add, two, reverse}, HloInstruction::FusionKind::kOutput);
+ RunAnalysis();
+
+ // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
+ fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+
+ auto make_cond = [this, &data_shape]() {
+ auto builder = HloComputation::Builder(TestName() + ".Cond");
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "data"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
+ return builder.Build();
+ };
+
+ auto make_body = [this, &data_shape]() {
+ auto builder = HloComputation::Builder(TestName() + ".Body");
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "data"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data));
+ return builder.Build();
+ };
+
+ module_ = CreateNewModule();
+ HloComputation* cond_computation =
+ module_->AddEmbeddedComputation(make_cond());
+ HloComputation* body_computation =
+ module_->AddEmbeddedComputation(make_body());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "data"));
+ auto whil = builder.AddInstruction(HloInstruction::CreateWhile(
+ data_shape, cond_computation, body_computation, data));
+ computation_ = module_->AddEntryComputation(builder.Build());
+
+ RunAnalysis();
+
+ // The While instruction can share with the data operand.
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {}));
+}
+
+// Tests that Call can alias operand buffer if the only use of the operand
+// in the called computation is an elementwise instruction.
+TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
+ Shape shape = ShapeUtil::MakeShape(F32, {8});
+ // Build sub-computation with fusion root.
+ auto sub_builder = HloComputation::Builder(TestName() + "_sub");
+ auto sub_param = sub_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "sub_param"));
+ auto one = sub_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto ones = sub_builder.AddInstruction(
+ HloInstruction::CreateBroadcast(shape, one, {1}));
+ auto add = sub_builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
+
+ module_ = CreateNewModule();
+ auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
+ sub_computation->CreateFusionInstruction({add, ones},
+ HloInstruction::FusionKind::kLoop);
+
+ // Build entry-computation with kCall which calls 'sub_computation'.
+ auto builder = HloComputation::Builder(TestName());
+
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ auto reverse =
+ builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
+ auto call = builder.AddInstruction(
+ HloInstruction::CreateCall(shape, {reverse}, sub_computation));
+ computation_ = module_->AddEntryComputation(builder.Build());
+
+ RunAnalysis();
+
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
+}
+
} // namespace
} // namespace xla