diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc | 79 |
1 files changed, 41 insertions, 38 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index ea7e479d66..be3fae5161 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*Literal::CreateR0<bool>(true)); + TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*Literal::CreateR1<uint32>({1, 2, 3})); + TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip( - *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, - r3_dim0minor)); + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); - TestInfeedRoundTrip( - *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, - r3_dim0major)); + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*Literal::CreateR4( + TestInfeedRoundTrip(*LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { TestInfeedRoundTrip( - *Literal::MakeTuple({Literal::CreateR1<uint32>({1, 2, 3}).get(), - Literal::CreateR0<bool>(false).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(), + LiteralUtil::CreateR0<bool>(false).get()})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*Literal::MakeTuple({})); + TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -156,13 +156,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { }); // Send 5 Infeed data of shape F32[3]. - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({1, 2, 3}))); - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({4, 5, 6}))); - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*Literal::CreateR1<float>({10, 11, 12}))); + client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12}))); ASSERT_IS_OK( - client_->TransferToInfeed(*Literal::CreateR1<float>({13, 14, 15}))); + client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); @@ -247,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(), - Literal::CreateR0<bool>(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(), + LiteralUtil::CreateR0<bool>(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1<float>({3, 4}).get(), - Literal::CreateR0<bool>(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(), + LiteralUtil::CreateR0<bool>(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1<float>({5, 6}).get(), - Literal::CreateR0<bool>(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(), + LiteralUtil::CreateR0<bool>(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1<float>({7, 8}).get(), - Literal::CreateR0<bool>(false).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(), + LiteralUtil::CreateR0<bool>(false).get()}))); // Asynchronously launch the execution on the device. std::unique_ptr<GlobalData> result; @@ -272,14 +275,14 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1<float>({1, 2, 3}).get(), - Literal::CreateR0<bool>(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(), + LiteralUtil::CreateR0<bool>(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1<float>({7, 8, 9}).get(), - Literal::CreateR0<bool>(false).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(), + LiteralUtil::CreateR0<bool>(false).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1<float>({4, 5, 6}).get(), - Literal::CreateR0<bool>(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(), + LiteralUtil::CreateR0<bool>(true).get()}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. |