diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/tuple_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/tuple_test.cc | 152 |
1 files changed, 70 insertions, 82 deletions
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index f2b3b49015..619d2a388b 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0<float>(constant_scalar).get(), - LiteralUtil::CreateR1<float>(constant_vector).get(), - LiteralUtil::CreateR2<float>(constant_matrix).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0<float>(constant_scalar), + LiteralUtil::CreateR1<float>(constant_vector), + LiteralUtil::CreateR2<float>(constant_matrix)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests a tuple made of scalar constants. @@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0<float>(constant_scalar1).get(), - LiteralUtil::CreateR0<float>(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0<float>(constant_scalar1), + LiteralUtil::CreateR0<float>(constant_scalar2)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests the creation of tuple data. @@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) { ConstantR1<float>(&builder, constant_vector), ConstantR2<float>(&builder, constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0<float>(constant_scalar).get(), - LiteralUtil::CreateR1<float>(constant_vector).get(), - LiteralUtil::CreateR2<float>(constant_matrix).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0<float>(constant_scalar), + LiteralUtil::CreateR1<float>(constant_vector), + LiteralUtil::CreateR2<float>(constant_matrix)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of tuple data. @@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { Tuple(&builder, {ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(), - LiteralUtil::CreateR1<float>({}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of an empty tuple. @@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); auto expected = LiteralUtil::MakeTuple({}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. @@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ConstantR2<float>(&builder, constant_matrix)}); Tuple(&builder, {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2<float>(constant_matrix).get(), - LiteralUtil::CreateR1<float>(constant_vector).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2<float>(constant_matrix), + LiteralUtil::CreateR1<float>(constant_vector)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { @@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(), - LiteralUtil::CreateR0<bool>(!direction).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0<bool>(direction), + LiteralUtil::CreateR0<bool>(!direction)}); - ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()}, error_spec_); } } @@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { ConstantR1<float>(&builder, vec1)}); Select(ConstantR0<bool>(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(), - LiteralUtil::CreateR1<float>(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, TuplesInAMap) { @@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { ConstantR1<float>(&builder, vec1)}); Select(ConstantR0<bool>(&builder, true), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(), - LiteralUtil::CreateR1<float>(vec2).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1<float>(vec1), LiteralUtil::CreateR1<float>(vec2)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { @@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { Select(ConstantR0<bool>(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(), - LiteralUtil::CreateR1<float>(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { @@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0}); auto expected_s = LiteralUtil::CreateR0<float>(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + LiteralUtil::MakeTuple({&expected_v1, &expected_s}); auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0}); - auto expected = - LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { @@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr<GlobalData> data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( - { - LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(), - }) - .get(), - LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}), + LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}), + }), + LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}), })) .ConsumeValueOrDie(); @@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) { std::unique_ptr<GlobalData> arg0 = client_ - ->TransferToServer(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0<complex64>({1, 2}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}) - .get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0<complex64>({1, 2}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}), LiteralUtil::CreateR2<complex64>( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, - {{10000, 20000}, {30000, 40000}}}) - .get()}) - .get()})) + {{10000, 20000}, {30000, 40000}}})})})) .ConsumeValueOrDie(); std::unique_ptr<GlobalData> arg1 = client_ ->TransferToServer( - *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}})) + LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); auto sum = LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = absl::make_unique<Literal>(sum->shape()); - ASSERT_TRUE(prod->Populate<complex64>( - [&sum](absl::Span<const int64> indexes) { - return sum->Get<complex64>(indexes) * - (indexes[indexes.size() - 1] == 0 - ? complex64(1, 2) - : complex64(1, -2)); - }) + Literal prod(sum.shape()); + ASSERT_TRUE(prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) { + return sum.Get<complex64>(indexes) * + (indexes[indexes.size() - 1] == 0 + ? complex64(1, 2) + : complex64(1, -2)); + }) .ok()); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), - LiteralUtil::CreateR0<complex64>({123, 456}).get()}); - ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices({prod, sum}), + LiteralUtil::CreateR0<complex64>({123, 456})}); + ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { .ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3})); - auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); + auto result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})), + result)); } // Disabled on interpreter due to lack of outfeed. @@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest, tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { TF_EXPECT_OK(Execute(std::move(module), - {param0.get(), param1.get(), param1.get(), - param0.get(), param4.get()}) + {¶m0, ¶m1, ¶m1, ¶m0, ¶m4}) .status()); })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3})); - auto literal = Literal::CreateFromShape(expected->shape()); + auto literal = Literal::CreateFromShape(expected.shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), *literal)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); + backend().default_stream_executor(), expected.shape(), literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal)); } } // namespace |