diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-11 00:58:07 +0100 |
---|---|---|
committer | Thomas Köppe <tkoeppe@google.com> | 2017-08-11 01:01:31 +0100 |
commit | 9103096c12faa1fdbdf806c2422c7d84fc2d0642 (patch) | |
tree | 6e31c5a689b0d8797826cb1c7ad97be133a31bc9 /tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc | |
parent | 822603aed3f20159f06284af5ce35efa81b95ed6 (diff) |
Merged commit includes the following changes:
164923041 by meheff:
Make HloAliasAnalysis updatable after changes to the HLO graph.
As part of this change make HloAliasAnalysis a thinner layer which
basically only holds a map from HloValue to HloBuffer and vice versa.
--
PiperOrigin-RevId: 164923041
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc | 241 |
1 files changed, 189 insertions, 52 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index e7a30ae13b..e2815d6e64 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include <memory> #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" @@ -27,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -41,18 +43,25 @@ class HloAliasAnalysisTest : public HloTestBase { // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. - const HloAliasAnalysis& RunAnalysis() { + HloAliasAnalysis& RunAnalysis() { analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie(); return *analysis_; } - // Return a vector of the buffers in the buffer set at the current position. + // Return a vector of the buffers in the buffer set at the current position + // sorted by buffer id. std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction, const ShapeIndex& index = {}) const { + std::set<HloBuffer::Id> buffer_ids; + for (const HloValue* value : analysis_->dataflow_analysis() + .GetValueSet(instruction, index) + .values()) { + buffer_ids.insert(analysis_->GetBufferContainingValue(*value).id()); + } + std::vector<HloBuffer> buffers; - for (const HloBuffer* buffer : - analysis_->GetBufferSet(instruction, index).buffers()) { - buffers.push_back(*buffer); + for (HloBuffer::Id id : buffer_ids) { + buffers.push_back(analysis_->GetBuffer(id)); } return buffers; } @@ -122,8 +131,8 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { GetValueDefinedAt(instruction)); } - EXPECT_FALSE(analysis.GetInstructionBufferSet(add).IsAmbiguous()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(add).IsDistinct()); + EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(add)); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(add)); EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } @@ -166,12 +175,12 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { // Verify the positions of an aliased buffer. EXPECT_THAT( - analysis.GetUniqueBufferAt(param0).positions(), + analysis.GetUniqueBufferAt(param0).ComputePositions(), UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}}, HloPosition{gte0, {}})); - EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); + EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple)); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(tuple)); EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } @@ -191,12 +200,12 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { const HloAliasAnalysis& analysis = RunAnalysis(); EXPECT_THAT( - analysis.GetUniqueBufferAt(param0).positions(), + analysis.GetUniqueBufferAt(param0).ComputePositions(), UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}}, HloPosition{tuple, {2}})); - EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); - EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); + EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple)); + EXPECT_FALSE(analysis.InstructionBuffersAreDistinct(tuple)); EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } @@ -226,16 +235,16 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { const HloAliasAnalysis& analysis = RunAnalysis(); // Verify aliasing of the kCall operands and the subcomputation parameters. - EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(), + EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(), UnorderedElementsAre(HloPosition{constant1, {}}, HloPosition{subparam0, {}})); - EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(), + EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(), UnorderedElementsAre(HloPosition{constant2, {}}, HloPosition{subparam1, {}})); // The subcomputation root and the kCall itself should alias. EXPECT_THAT( - analysis.GetUniqueBufferAt(add).positions(), + analysis.GetUniqueBufferAt(add).ComputePositions(), UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call, {}})); EXPECT_FALSE(AnyValuesInSameBufferInterfere()); @@ -266,10 +275,10 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { const HloAliasAnalysis& analysis = RunAnalysis(); - EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(), + EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(), UnorderedElementsAre(HloPosition{constant1, {}}, HloPosition{subparam0, {}})); - EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(), + EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(), UnorderedElementsAre(HloPosition{constant2, {}}, HloPosition{subparam1, {}})); @@ -277,7 +286,7 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { // and the first parameter of the subcomputation because 'call1' it is passed // as an argument to the subcomputation in 'call2'. EXPECT_THAT( - analysis.GetUniqueBufferAt(add).positions(), + analysis.GetUniqueBufferAt(add).ComputePositions(), UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}}, HloPosition{subparam0, {}}, HloPosition{call2, {}})); @@ -287,10 +296,10 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { EXPECT_THAT(GetBuffersAt(subparam1), UnorderedElementsAre(analysis.GetUniqueBufferAt(constant2))); - EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsAmbiguous()); - EXPECT_FALSE(analysis.GetInstructionBufferSet(subparam1).IsAmbiguous()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsDistinct()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam1).IsDistinct()); + EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(subparam0)); + EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(subparam1)); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam0)); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam1)); EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } @@ -352,23 +361,26 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { const HloAliasAnalysis& analysis = RunAnalysis(); // Verify the positions of the aliased while buffers. - EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).positions(), - UnorderedElementsAre( - HloPosition{tuple, {}}, HloPosition{xla_while, {}}, - HloPosition{body_param, {}}, HloPosition{body_tuple, {}}, - HloPosition{cond_param, {}})); - EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).positions(), - UnorderedElementsAre( - HloPosition{constant1, {}}, HloPosition{tuple, {0}}, - HloPosition{xla_while, {0}}, HloPosition{body_param, {0}}, - HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}}, - HloPosition{cond_param, {0}})); - EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).positions(), - UnorderedElementsAre( - HloPosition{constant2, {}}, HloPosition{tuple, {1}}, - HloPosition{xla_while, {1}}, HloPosition{body_param, {1}}, - HloPosition{body_element_1, {}}, HloPosition{add, {}}, - HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}})); + EXPECT_THAT( + analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).ComputePositions(), + UnorderedElementsAre(HloPosition{tuple, {}}, HloPosition{xla_while, {}}, + HloPosition{body_param, {}}, + HloPosition{body_tuple, {}}, + HloPosition{cond_param, {}})); + EXPECT_THAT( + analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).ComputePositions(), + UnorderedElementsAre( + HloPosition{constant1, {}}, HloPosition{tuple, {0}}, + HloPosition{xla_while, {0}}, HloPosition{body_param, {0}}, + HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}}, + HloPosition{cond_param, {0}})); + EXPECT_THAT( + analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(), + UnorderedElementsAre( + HloPosition{constant2, {}}, HloPosition{tuple, {1}}, + HloPosition{xla_while, {1}}, HloPosition{body_param, {1}}, + HloPosition{body_element_1, {}}, HloPosition{add, {}}, + HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}})); EXPECT_THAT( GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})), @@ -446,6 +458,9 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); module_->AddEntryComputation(builder.Build()); + FlattenCallGraph flattener; + TF_ASSERT_OK(flattener.Run(module_.get()).status()); + const HloAliasAnalysis& analysis = RunAnalysis(); EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}), @@ -689,15 +704,15 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { analysis.GetUniqueBufferAt(constant3), analysis.GetUniqueBufferAt(constant4))); - EXPECT_FALSE(analysis.GetInstructionBufferSet(select11).IsAmbiguous()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsAmbiguous()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsAmbiguous()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsAmbiguous()); + EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select11)); + EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select12)); + EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select34)); + EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select1234)); - EXPECT_TRUE(analysis.GetInstructionBufferSet(select11).IsDistinct()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsDistinct()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsDistinct()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsDistinct()); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select11)); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select12)); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select34)); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select1234)); EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } @@ -776,11 +791,11 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { GetValueDefinedAt(body_param, /*index=*/{0}), GetValueDefinedAt(cond_param, /*index=*/{0}), GetValueDefinedAt(negate))); - EXPECT_FALSE(analysis.GetInstructionBufferSet(select).IsAmbiguous()); - EXPECT_FALSE(analysis.GetInstructionBufferSet(xla_while).IsAmbiguous()); + EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select)); + EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(xla_while)); - EXPECT_TRUE(analysis.GetInstructionBufferSet(select).IsDistinct()); - EXPECT_TRUE(analysis.GetInstructionBufferSet(xla_while).IsDistinct()); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select)); + EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(xla_while)); // The two operands of the select get flattened into the same buffer resulting // in liveness interference. @@ -805,5 +820,127 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { analysis.GetUniqueBufferAt(bitcast)); } +TEST_F(HloAliasAnalysisTest, UpdateAnalysisForWhile) { + // Test updating alias analysis after modifying a module with an array shaped + // while: + // + // body(F32[] %param): + // %negate = Negate(%param) + // + // condition(F32[] %param): + // return Constant(false) + // + // entry: + // %constant = Constant(1.0) + // %exp = Exp(%constant) + // return While(%exp, body, condition) + // + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, body_param)); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + // Condition computation trivially returns a constant "false". + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, exp)); + module_->AddEntryComputation(builder.Build()); + + HloAliasAnalysis& analysis = RunAnalysis(); + + // Sanity check some alias information. + EXPECT_EQ(analysis.GetUniqueBufferAt(exp), + analysis.GetUniqueBufferAt(body_param)); + EXPECT_EQ(analysis.GetUniqueBufferAt(exp), + analysis.GetUniqueBufferAt(cond_param)); + EXPECT_EQ(analysis.GetUniqueBufferAt(exp), + analysis.GetUniqueBufferAt(negate)); + EXPECT_EQ(analysis.GetUniqueBufferAt(exp), + analysis.GetUniqueBufferAt(xla_while)); + + // Set the body root to the body_param. Previously it was Negate(body_param). + body->set_root_instruction(body_param); + + // Prior to updating, verify that the analysis is no longer valid. + Status verify_status = analysis.VerifyAgainstReference(); + EXPECT_FALSE(verify_status.ok()); + + analysis.UpdateAfterChangingRoot(/*old_root=*/negate, + /*new_root*/ body_param); + + // Analysis should be valid after the update. + TF_ASSERT_OK(analysis.VerifyAgainstReference()); + + // The exponential should now pass through the body transparently. + EXPECT_EQ(analysis.GetUniqueBufferAt(exp), + analysis.GetUniqueBufferAt(body_param)); + EXPECT_EQ(analysis.GetUniqueBufferAt(exp), + analysis.GetUniqueBufferAt(cond_param)); + EXPECT_NE(analysis.GetUniqueBufferAt(exp), + analysis.GetUniqueBufferAt(negate)); + EXPECT_EQ(analysis.GetUniqueBufferAt(exp), + analysis.GetUniqueBufferAt(xla_while)); + + // Now replace the operand of the while with %constant (was %exp). + TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant)); + analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp, + /*new_operand=*/constant); + + // Analysis should be valid after the update. + TF_ASSERT_OK(analysis.VerifyAgainstReference()); + + EXPECT_EQ(analysis.GetUniqueBufferAt(constant), + analysis.GetUniqueBufferAt(body_param)); + EXPECT_EQ(analysis.GetUniqueBufferAt(constant), + analysis.GetUniqueBufferAt(cond_param)); + EXPECT_EQ(analysis.GetUniqueBufferAt(constant), + analysis.GetUniqueBufferAt(xla_while)); + EXPECT_NE(analysis.GetUniqueBufferAt(constant), + analysis.GetUniqueBufferAt(exp)); + EXPECT_NE(analysis.GetUniqueBufferAt(constant), + analysis.GetUniqueBufferAt(negate)); + + // And finally make the negate the root of the body again. + body->set_root_instruction(negate); + analysis.UpdateAfterChangingRoot(/*old_root=*/body_param, + /*new_root*/ negate); + + // Analysis should be valid after the update. + TF_ASSERT_OK(analysis.VerifyAgainstReference()); + + EXPECT_EQ(analysis.GetUniqueBufferAt(negate), + analysis.GetUniqueBufferAt(body_param)); + EXPECT_EQ(analysis.GetUniqueBufferAt(negate), + analysis.GetUniqueBufferAt(cond_param)); + EXPECT_EQ(analysis.GetUniqueBufferAt(negate), + analysis.GetUniqueBufferAt(xla_while)); + EXPECT_EQ(analysis.GetUniqueBufferAt(constant), + analysis.GetUniqueBufferAt(negate)); + + auto value_of = [&analysis](const HloInstruction* instruction) { + return &analysis.dataflow_analysis().GetValueDefinedAt(instruction); + }; + EXPECT_THAT(analysis.GetUniqueBufferAt(negate).values(), + UnorderedElementsAre(value_of(body_param), value_of(cond_param), + value_of(negate), value_of(constant), + value_of(xla_while))); +} + +// Test update tuple element. + } // namespace } // namespace xla |