aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc79
1 files changed, 41 insertions, 38 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index ea7e479d66..be3fae5161 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*Literal::CreateR0<bool>(true));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*Literal::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(
- *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
- r3_dim0minor));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+ r3_dim0minor));
- TestInfeedRoundTrip(
- *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
- r3_dim0major));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+ r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*Literal::CreateR4(
+ TestInfeedRoundTrip(*LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
TestInfeedRoundTrip(
- *Literal::MakeTuple({Literal::CreateR1<uint32>({1, 2, 3}).get(),
- Literal::CreateR0<bool>(false).get()}));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*Literal::MakeTuple({}));
+ TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
}
// Tests Infeed operation used in a while loop, as in the code below. The
@@ -156,13 +156,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
});
// Send 5 Infeed data of shape F32[3].
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({1, 2, 3})));
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({4, 5, 6})));
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*Literal::CreateR1<float>({10, 11, 12})));
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*Literal::CreateR1<float>({13, 14, 15})));
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
@@ -247,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({3, 4}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({5, 6}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({7, 8}).get(),
- Literal::CreateR0<bool>(false).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -272,14 +275,14 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({1, 2, 3}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({7, 8, 9}).get(),
- Literal::CreateR0<bool>(false).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({4, 5, 6}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.