aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-11 00:58:07 +0100
committerGravatar Thomas Köppe <tkoeppe@google.com>2017-08-11 01:01:31 +0100
commit9103096c12faa1fdbdf806c2422c7d84fc2d0642 (patch)
tree6e31c5a689b0d8797826cb1c7ad97be133a31bc9 /tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
parent822603aed3f20159f06284af5ce35efa81b95ed6 (diff)
Merged commit includes the following changes:
164923041 by meheff: Make HloAliasAnalysis updatable after changes to the HLO graph. As part of this change make HloAliasAnalysis a thinner layer which basically only holds a map from HloValue to HloBuffer and vice versa. -- PiperOrigin-RevId: 164923041
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc241
1 files changed, 189 insertions, 52 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index e7a30ae13b..e2815d6e64 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
@@ -27,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -41,18 +43,25 @@ class HloAliasAnalysisTest : public HloTestBase {
// Run alias analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
- const HloAliasAnalysis& RunAnalysis() {
+ HloAliasAnalysis& RunAnalysis() {
analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie();
return *analysis_;
}
- // Return a vector of the buffers in the buffer set at the current position.
+ // Return a vector of the buffers in the buffer set at the current position
+ // sorted by buffer id.
std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) const {
+ std::set<HloBuffer::Id> buffer_ids;
+ for (const HloValue* value : analysis_->dataflow_analysis()
+ .GetValueSet(instruction, index)
+ .values()) {
+ buffer_ids.insert(analysis_->GetBufferContainingValue(*value).id());
+ }
+
std::vector<HloBuffer> buffers;
- for (const HloBuffer* buffer :
- analysis_->GetBufferSet(instruction, index).buffers()) {
- buffers.push_back(*buffer);
+ for (HloBuffer::Id id : buffer_ids) {
+ buffers.push_back(analysis_->GetBuffer(id));
}
return buffers;
}
@@ -122,8 +131,8 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) {
GetValueDefinedAt(instruction));
}
- EXPECT_FALSE(analysis.GetInstructionBufferSet(add).IsAmbiguous());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(add).IsDistinct());
+ EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(add));
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(add));
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
@@ -166,12 +175,12 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) {
// Verify the positions of an aliased buffer.
EXPECT_THAT(
- analysis.GetUniqueBufferAt(param0).positions(),
+ analysis.GetUniqueBufferAt(param0).ComputePositions(),
UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
HloPosition{gte0, {}}));
- EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
+ EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple));
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(tuple));
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
@@ -191,12 +200,12 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
const HloAliasAnalysis& analysis = RunAnalysis();
EXPECT_THAT(
- analysis.GetUniqueBufferAt(param0).positions(),
+ analysis.GetUniqueBufferAt(param0).ComputePositions(),
UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
HloPosition{tuple, {2}}));
- EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
- EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
+ EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple));
+ EXPECT_FALSE(analysis.InstructionBuffersAreDistinct(tuple));
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
@@ -226,16 +235,16 @@ TEST_F(HloAliasAnalysisTest, SingleCall) {
const HloAliasAnalysis& analysis = RunAnalysis();
// Verify aliasing of the kCall operands and the subcomputation parameters.
- EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(),
+ EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(),
UnorderedElementsAre(HloPosition{constant1, {}},
HloPosition{subparam0, {}}));
- EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(),
+ EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(),
UnorderedElementsAre(HloPosition{constant2, {}},
HloPosition{subparam1, {}}));
// The subcomputation root and the kCall itself should alias.
EXPECT_THAT(
- analysis.GetUniqueBufferAt(add).positions(),
+ analysis.GetUniqueBufferAt(add).ComputePositions(),
UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call, {}}));
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
@@ -266,10 +275,10 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
const HloAliasAnalysis& analysis = RunAnalysis();
- EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(),
+ EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(),
UnorderedElementsAre(HloPosition{constant1, {}},
HloPosition{subparam0, {}}));
- EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(),
+ EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(),
UnorderedElementsAre(HloPosition{constant2, {}},
HloPosition{subparam1, {}}));
@@ -277,7 +286,7 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
// and the first parameter of the subcomputation because 'call1' it is passed
// as an argument to the subcomputation in 'call2'.
EXPECT_THAT(
- analysis.GetUniqueBufferAt(add).positions(),
+ analysis.GetUniqueBufferAt(add).ComputePositions(),
UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}},
HloPosition{subparam0, {}}, HloPosition{call2, {}}));
@@ -287,10 +296,10 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
EXPECT_THAT(GetBuffersAt(subparam1),
UnorderedElementsAre(analysis.GetUniqueBufferAt(constant2)));
- EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsAmbiguous());
- EXPECT_FALSE(analysis.GetInstructionBufferSet(subparam1).IsAmbiguous());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsDistinct());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam1).IsDistinct());
+ EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(subparam0));
+ EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(subparam1));
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam0));
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam1));
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
@@ -352,23 +361,26 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
const HloAliasAnalysis& analysis = RunAnalysis();
// Verify the positions of the aliased while buffers.
- EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).positions(),
- UnorderedElementsAre(
- HloPosition{tuple, {}}, HloPosition{xla_while, {}},
- HloPosition{body_param, {}}, HloPosition{body_tuple, {}},
- HloPosition{cond_param, {}}));
- EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).positions(),
- UnorderedElementsAre(
- HloPosition{constant1, {}}, HloPosition{tuple, {0}},
- HloPosition{xla_while, {0}}, HloPosition{body_param, {0}},
- HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}},
- HloPosition{cond_param, {0}}));
- EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).positions(),
- UnorderedElementsAre(
- HloPosition{constant2, {}}, HloPosition{tuple, {1}},
- HloPosition{xla_while, {1}}, HloPosition{body_param, {1}},
- HloPosition{body_element_1, {}}, HloPosition{add, {}},
- HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}}));
+ EXPECT_THAT(
+ analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).ComputePositions(),
+ UnorderedElementsAre(HloPosition{tuple, {}}, HloPosition{xla_while, {}},
+ HloPosition{body_param, {}},
+ HloPosition{body_tuple, {}},
+ HloPosition{cond_param, {}}));
+ EXPECT_THAT(
+ analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).ComputePositions(),
+ UnorderedElementsAre(
+ HloPosition{constant1, {}}, HloPosition{tuple, {0}},
+ HloPosition{xla_while, {0}}, HloPosition{body_param, {0}},
+ HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}},
+ HloPosition{cond_param, {0}}));
+ EXPECT_THAT(
+ analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
+ UnorderedElementsAre(
+ HloPosition{constant2, {}}, HloPosition{tuple, {1}},
+ HloPosition{xla_while, {1}}, HloPosition{body_param, {1}},
+ HloPosition{body_element_1, {}}, HloPosition{add, {}},
+ HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}}));
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
@@ -446,6 +458,9 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
module_->AddEntryComputation(builder.Build());
+ FlattenCallGraph flattener;
+ TF_ASSERT_OK(flattener.Run(module_.get()).status());
+
const HloAliasAnalysis& analysis = RunAnalysis();
EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
@@ -689,15 +704,15 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) {
analysis.GetUniqueBufferAt(constant3),
analysis.GetUniqueBufferAt(constant4)));
- EXPECT_FALSE(analysis.GetInstructionBufferSet(select11).IsAmbiguous());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsAmbiguous());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsAmbiguous());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsAmbiguous());
+ EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select11));
+ EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select12));
+ EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select34));
+ EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select1234));
- EXPECT_TRUE(analysis.GetInstructionBufferSet(select11).IsDistinct());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsDistinct());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsDistinct());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsDistinct());
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select11));
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select12));
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select34));
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select1234));
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
@@ -776,11 +791,11 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) {
GetValueDefinedAt(body_param, /*index=*/{0}),
GetValueDefinedAt(cond_param, /*index=*/{0}),
GetValueDefinedAt(negate)));
- EXPECT_FALSE(analysis.GetInstructionBufferSet(select).IsAmbiguous());
- EXPECT_FALSE(analysis.GetInstructionBufferSet(xla_while).IsAmbiguous());
+ EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select));
+ EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(xla_while));
- EXPECT_TRUE(analysis.GetInstructionBufferSet(select).IsDistinct());
- EXPECT_TRUE(analysis.GetInstructionBufferSet(xla_while).IsDistinct());
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select));
+ EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(xla_while));
// The two operands of the select get flattened into the same buffer resulting
// in liveness interference.
@@ -805,5 +820,127 @@ TEST_F(HloAliasAnalysisTest, Bitcast) {
analysis.GetUniqueBufferAt(bitcast));
}
+TEST_F(HloAliasAnalysisTest, UpdateAnalysisForWhile) {
+ // Test updating alias analysis after modifying a module with an array shaped
+ // while:
+ //
+ // body(F32[] %param):
+ // %negate = Negate(%param)
+ //
+ // condition(F32[] %param):
+ // return Constant(false)
+ //
+ // entry:
+ // %constant = Constant(1.0)
+ // %exp = Exp(%constant)
+ // return While(%exp, body, condition)
+ //
+ auto body_builder = HloComputation::Builder("body");
+ auto body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape_, HloOpcode::kNegate, body_param));
+ HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
+
+ // Condition computation trivially returns a constant "false".
+ auto cond_builder = HloComputation::Builder("condition");
+ auto cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ cond_builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloComputation* condition =
+ module_->AddEmbeddedComputation(cond_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
+ auto xla_while = builder.AddInstruction(
+ HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
+ module_->AddEntryComputation(builder.Build());
+
+ HloAliasAnalysis& analysis = RunAnalysis();
+
+ // Sanity check some alias information.
+ EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
+ analysis.GetUniqueBufferAt(body_param));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
+ analysis.GetUniqueBufferAt(cond_param));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
+ analysis.GetUniqueBufferAt(negate));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
+ analysis.GetUniqueBufferAt(xla_while));
+
+ // Set the body root to the body_param. Previously it was Negate(body_param).
+ body->set_root_instruction(body_param);
+
+ // Prior to updating, verify that the analysis is no longer valid.
+ Status verify_status = analysis.VerifyAgainstReference();
+ EXPECT_FALSE(verify_status.ok());
+
+ analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
+ /*new_root*/ body_param);
+
+ // Analysis should be valid after the update.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+
+ // The exponential should now pass through the body transparently.
+ EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
+ analysis.GetUniqueBufferAt(body_param));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
+ analysis.GetUniqueBufferAt(cond_param));
+ EXPECT_NE(analysis.GetUniqueBufferAt(exp),
+ analysis.GetUniqueBufferAt(negate));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
+ analysis.GetUniqueBufferAt(xla_while));
+
+ // Now replace the operand of the while with %constant (was %exp).
+ TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
+ analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
+ /*new_operand=*/constant);
+
+ // Analysis should be valid after the update.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+
+ EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
+ analysis.GetUniqueBufferAt(body_param));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
+ analysis.GetUniqueBufferAt(cond_param));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
+ analysis.GetUniqueBufferAt(xla_while));
+ EXPECT_NE(analysis.GetUniqueBufferAt(constant),
+ analysis.GetUniqueBufferAt(exp));
+ EXPECT_NE(analysis.GetUniqueBufferAt(constant),
+ analysis.GetUniqueBufferAt(negate));
+
+ // And finally make the negate the root of the body again.
+ body->set_root_instruction(negate);
+ analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
+ /*new_root*/ negate);
+
+ // Analysis should be valid after the update.
+ TF_ASSERT_OK(analysis.VerifyAgainstReference());
+
+ EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
+ analysis.GetUniqueBufferAt(body_param));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
+ analysis.GetUniqueBufferAt(cond_param));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
+ analysis.GetUniqueBufferAt(xla_while));
+ EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
+ analysis.GetUniqueBufferAt(negate));
+
+ auto value_of = [&analysis](const HloInstruction* instruction) {
+ return &analysis.dataflow_analysis().GetValueDefinedAt(instruction);
+ };
+ EXPECT_THAT(analysis.GetUniqueBufferAt(negate).values(),
+ UnorderedElementsAre(value_of(body_param), value_of(cond_param),
+ value_of(negate), value_of(constant),
+ value_of(xla_while)));
+}
+
+// Test update tuple element.
+
} // namespace
} // namespace xla