diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 0cd0ab36fc..5c8d97b2d1 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } +TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + // Cannot alias an output twice. + ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { + // parameter 0 aliased with output 1 and parameter 1 aliased with output 0. + // + // (p0 , p1) + // \ / + // \ / + // alias X + // / \ + // / \ + // (p0 , p1) + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + // Cannot alias an output twice. + ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + // Every Ops in this graph are aliased with each other. + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { + // Test a simple single while instruction can be aliased with input and output + // of the computation. + // + // body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %param1 = param1 + // %while = While(%param1, body, condition) + // %while_1 = GTE(%while, 0) + // %while_2 = GTE(%while, 1) + // %negate_1 = Negate(%while_1) + // %negate_2 = Negate(%while_2) + // return Tuple(negate_1, negate_2) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Element 0 passes transparently through the body. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); + auto body_tuple = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_0, add})); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + // Condition computation trivially returns a constant "false". + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, param)); + auto while_element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0)); + auto while_element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1)); + auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_1)); + auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_2)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_THAT( + GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})), + UnorderedElementsAre(GetValueDefinedAt(param, {1}), + GetValueDefinedAt(xla_while, /*index=*/{1}), + GetValueDefinedAt(body_param, {1}), + GetValueDefinedAt(cond_param, {1}), + GetValueDefinedAt(add), + GetValueDefinedAt(negate_2))); + + EXPECT_THAT( + analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(), + UnorderedElementsAre( + HloPosition{param, {1}}, HloPosition{xla_while, {1}}, + HloPosition{while_element_2, {}}, HloPosition{body_param, {1}}, + HloPosition{body_element_1, {}}, HloPosition{add, {}}, + HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}}, + HloPosition{cond_param, {1}}, HloPosition{negate_2, {}})); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); +} + TEST_F(HloAliasAnalysisTest, SingleCall) { // Test a single call of a subcomputation. The subcomputation adds its two // array-shaped parameters. |