aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/tuple_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/tuple_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc152
1 files changed, 70 insertions, 82 deletions
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index f2b3b49015..619d2a388b 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests a tuple made of scalar constants.
@@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
- auto value = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar1).get(),
- LiteralUtil::CreateR0<float>(constant_scalar2).get()});
+ auto value = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar1),
+ LiteralUtil::CreateR0<float>(constant_scalar2)});
- ConstantLiteral(&builder, *value);
- ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+ ConstantLiteral(&builder, value);
+ ComputeAndCompareTuple(&builder, value, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) {
ConstantR1<float>(&builder, constant_vector),
ConstantR2<float>(&builder, constant_matrix)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<float>(constant_scalar).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get(),
- LiteralUtil::CreateR2<float>(constant_matrix).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(constant_scalar),
+ LiteralUtil::CreateR1<float>(constant_vector),
+ LiteralUtil::CreateR2<float>(constant_matrix)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of tuple data.
@@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
Tuple(&builder,
{ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
- LiteralUtil::CreateR1<float>({}).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Tests the creation of an empty tuple.
@@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) {
XlaBuilder builder(TestName());
Tuple(&builder, {});
auto expected = LiteralUtil::MakeTuple({});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
@@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ConstantR2<float>(&builder, constant_matrix)});
Tuple(&builder,
{GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR2<float>(constant_matrix).get(),
- LiteralUtil::CreateR1<float>(constant_vector).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR2<float>(constant_matrix),
+ LiteralUtil::CreateR1<float>(constant_vector)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
@@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true}
auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false}
Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(),
- LiteralUtil::CreateR0<bool>(!direction).get()});
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<bool>(direction),
+ LiteralUtil::CreateR0<bool>(!direction)});
- ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
+ ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
error_spec_);
}
}
@@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, TuplesInAMap) {
@@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
ConstantR1<float>(&builder, vec1)});
Select(ConstantR0<bool>(&builder, true), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
- LiteralUtil::CreateR1<float>(vec2).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec1), LiteralUtil::CreateR1<float>(vec2)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
@@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected =
- LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
- LiteralUtil::CreateR1<float>(vec1).get()});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
@@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) {
auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
auto expected_s = LiteralUtil::CreateR0<float>(42.0);
auto expected_inner_tuple =
- LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ LiteralUtil::MakeTuple({&expected_v1, &expected_s});
auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
- auto expected =
- LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+ auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2});
- ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+ ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
@@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple({
- LiteralUtil::MakeTuple(
- {
- LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
- LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
- })
- .get(),
- LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::MakeTupleFromSlices({
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
+ }),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
}))
.ConsumeValueOrDie();
@@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
std::unique_ptr<GlobalData> arg0 =
client_
- ->TransferToServer(*LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR0<complex64>({1, 2}).get(),
- LiteralUtil::MakeTuple(
- {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}})
- .get(),
+ ->TransferToServer(LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR0<complex64>({1, 2}),
+ LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
LiteralUtil::CreateR2<complex64>(
{{{100, 200}, {300, 400}},
{{1000, 2000}, {3000, 4000}},
- {{10000, 20000}, {30000, 40000}}})
- .get()})
- .get()}))
+ {{10000, 20000}, {30000, 40000}}})})}))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> arg1 =
client_
->TransferToServer(
- *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
+ LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
.ConsumeValueOrDie();
auto sum =
LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
{{1011, 2022}, {3031, 4042}},
{{10011, 20022}, {30031, 40042}}});
- auto prod = absl::make_unique<Literal>(sum->shape());
- ASSERT_TRUE(prod->Populate<complex64>(
- [&sum](absl::Span<const int64> indexes) {
- return sum->Get<complex64>(indexes) *
- (indexes[indexes.size() - 1] == 0
- ? complex64(1, 2)
- : complex64(1, -2));
- })
+ Literal prod(sum.shape());
+ ASSERT_TRUE(prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) {
+ return sum.Get<complex64>(indexes) *
+ (indexes[indexes.size() - 1] == 0
+ ? complex64(1, 2)
+ : complex64(1, -2));
+ })
.ok());
- auto expected = LiteralUtil::MakeTuple(
- {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(),
- LiteralUtil::CreateR0<complex64>({123, 456}).get()});
- ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
+ auto expected = LiteralUtil::MakeTupleFromSlices(
+ {LiteralUtil::MakeTupleFromSlices({prod, sum}),
+ LiteralUtil::CreateR0<complex64>({123, 456})});
+ ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
error_spec_);
}
@@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
.ValueOrDie();
auto param =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
- auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
+ auto result = ExecuteNoHloPasses(std::move(module), {&param});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
- *result));
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
+ result));
}
// Disabled on interpreter due to lack of outfeed.
@@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest,
tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "execute_thread", [&] {
TF_EXPECT_OK(Execute(std::move(module),
- {param0.get(), param1.get(), param1.get(),
- param0.get(), param4.get()})
+ {&param0, &param1, &param1, &param0, &param4})
.status());
}));
auto expected =
LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
- auto literal = Literal::CreateFromShape(expected->shape());
+ auto literal = Literal::CreateFromShape(expected.shape());
TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
- backend().default_stream_executor(), expected->shape(), *literal));
- EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal));
+ backend().default_stream_executor(), expected.shape(), literal));
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
}
} // namespace