aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc97
1 files changed, 50 insertions, 47 deletions
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 5734f28407..0ac8df4271 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase {
TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
// tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({inner_tuple, constant3}));
@@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
// tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({inner_tuple, constant3}));
@@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
// Create a tuple which contains duplicate elements.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant, constant, constant}));
@@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
// the same.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto copy = builder.AddInstruction(
@@ -317,9 +317,10 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
// Send forwards its operand to the output tuple at {0}.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto send = builder.AddInstruction(
- HloInstruction::CreateSend(constant, /*channel_id=*/0));
+ HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
BuildModuleAndRunAnalysis(builder.Build());
@@ -342,8 +343,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
// RecvDone forwards its operand tuple element at {0} to the output.
auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
- ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
+ ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
BuildModuleAndRunAnalysis(builder.Build());
@@ -355,7 +357,7 @@ TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
- ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
+ ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}});
}
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
@@ -363,18 +365,18 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
// set containing the union of both sides.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({constant2, constant2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
BuildModuleAndRunAnalysis(builder.Build());
@@ -401,9 +403,9 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, tuple_shape, "param1"));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, param0, param1));
+ tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1));
auto copy = builder.AddInstruction(
HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select));
@@ -441,18 +443,18 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
// Select from two identical tuples. The result should not be ambiguous.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
BuildModuleAndRunAnalysis(builder.Build());
@@ -472,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
// the right values.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto inner_tuple2 = builder.AddInstruction(
@@ -486,9 +488,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
BuildModuleAndRunAnalysis(builder.Build());
@@ -519,9 +521,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
// have the operand of the bitcast in its points-to set.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
constant2->shape(), HloOpcode::kBitcast, constant2));
auto tuple =
@@ -555,9 +557,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
- auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::MakeTuple({Literal::CreateR2<float>({{1.0}, {2.0}}).get(),
- Literal::CreateR1<float>({2.0, 42}).get()})));
+ auto tuple_constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
+ LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
@@ -577,9 +580,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
// times. Verify buffer alias sets.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple = builder.AddInstruction(
@@ -618,7 +621,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
auto tuple_element1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1));
auto ones = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
// Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones)
auto update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, tuple_element1, ones));
@@ -866,9 +869,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -960,9 +963,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -1014,9 +1017,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto a = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
auto b = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
@@ -1025,7 +1028,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -1047,7 +1050,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -1055,7 +1058,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
auto two = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));
@@ -1120,7 +1123,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "sub_param"));
auto one = sub_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
auto add = sub_builder.AddInstruction(