aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc175
1 files changed, 175 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index 0cd0ab36fc..5c8d97b2d1 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
+TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+ auto negate0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
+ module_->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+ // Cannot alias an output twice.
+ ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
+
+ const HloAliasAnalysis& analysis = RunAnalysis();
+
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+}
+
+TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
+ // parameter 0 aliased with output 1 and parameter 1 aliased with output 0.
+ //
+ // (p0 , p1)
+ // \ /
+ // \ /
+ // alias X
+ // / \
+ // / \
+ // (p0 , p1)
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+ module_->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}));
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
+
+ // Cannot alias an output twice.
+ ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+ const HloAliasAnalysis& analysis = RunAnalysis();
+
+ // Every Ops in this graph are aliased with each other.
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+}
+
+TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
+ // Test a simple single while instruction can be aliased with input and output
+ // of the computation.
+ //
+ // body((F32[], F32[]) %tuple_param):
+ // %add = Add(%tuple_param{0}, %tuple_param{1})
+ // return Tuple(%tuple_param{0}, %add)
+ //
+ // condition((F32[], F32[]) %tuple_param):
+ // return Constant(false)
+ //
+ // entry:
+ // %param1 = param1
+ // %while = While(%param1, body, condition)
+ // %while_1 = GTE(%while, 0)
+ // %while_2 = GTE(%while, 1)
+ // %negate_1 = Negate(%while_1)
+ // %negate_2 = Negate(%while_2)
+ // return Tuple(negate_1, negate_2)
+ //
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+ // Element 0 passes transparently through the body.
+ auto body_builder = HloComputation::Builder("body");
+ auto body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param"));
+ auto body_element_0 = body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
+ auto body_element_1 = body_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
+ auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
+ auto body_tuple = body_builder.AddInstruction(
+ HloInstruction::CreateTuple({body_element_0, add}));
+ HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
+
+ // Condition computation trivially returns a constant "false".
+ auto cond_builder = HloComputation::Builder("condition");
+ auto cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param"));
+ cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ HloComputation* condition =
+ module_->AddEmbeddedComputation(cond_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+
+ auto xla_while = builder.AddInstruction(
+ HloInstruction::CreateWhile(tuple_shape, condition, body, param));
+ auto while_element_1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0));
+ auto while_element_2 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1));
+ auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape_, HloOpcode::kNegate, while_element_1));
+ auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape_, HloOpcode::kNegate, while_element_2));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2}));
+ module_->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+ TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+ /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+ const HloAliasAnalysis& analysis = RunAnalysis();
+
+ EXPECT_THAT(
+ GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
+ UnorderedElementsAre(GetValueDefinedAt(param, {1}),
+ GetValueDefinedAt(xla_while, /*index=*/{1}),
+ GetValueDefinedAt(body_param, {1}),
+ GetValueDefinedAt(cond_param, {1}),
+ GetValueDefinedAt(add),
+ GetValueDefinedAt(negate_2)));
+
+ EXPECT_THAT(
+ analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
+ UnorderedElementsAre(
+ HloPosition{param, {1}}, HloPosition{xla_while, {1}},
+ HloPosition{while_element_2, {}}, HloPosition{body_param, {1}},
+ HloPosition{body_element_1, {}}, HloPosition{add, {}},
+ HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}},
+ HloPosition{cond_param, {1}}, HloPosition{negate_2, {}}));
+
+ EXPECT_FALSE(AnyValuesInSameBufferInterfere());
+}
+
TEST_F(HloAliasAnalysisTest, SingleCall) {
// Test a single call of a subcomputation. The subcomputation adds its two
// array-shaped parameters.