aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/round_trip_transfer_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_transfer_test.cc50
1 files changed, 27 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index f334a8c131..a8193c2eac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -46,61 +46,62 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
};
TEST_F(RoundTripTransferTest, R0S32) {
- RoundTripTest(*Literal::CreateR0<int32>(42));
+ RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
}
TEST_F(RoundTripTransferTest, R0F32) {
- RoundTripTest(*Literal::CreateR0<float>(42.0));
+ RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
}
TEST_F(RoundTripTransferTest, R1F32_Len0) {
- RoundTripTest(*Literal::CreateR1<float>({}));
+ RoundTripTest(*LiteralUtil::CreateR1<float>({}));
}
TEST_F(RoundTripTransferTest, R1F32_Len2) {
- RoundTripTest(*Literal::CreateR1<float>({42.0, 64.0}));
+ RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
}
TEST_F(RoundTripTransferTest, R1F32_Len256) {
std::vector<float> values(256);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1024) {
std::vector<float> values(1024);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1025) {
std::vector<float> values(1025);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len4096) {
std::vector<float> values(4096);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
- RoundTripTest(*Literal::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+ RoundTripTest(
+ *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
}
TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
- RoundTripTest(*Literal::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+ RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
}
TEST_F(RoundTripTransferTest, R3F32) {
RoundTripTest(
- *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
}
TEST_F(RoundTripTransferTest, R4F32) {
- RoundTripTest(*Literal::CreateR4<float>({{
+ RoundTripTest(*LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -108,33 +109,36 @@ TEST_F(RoundTripTransferTest, R4F32) {
}
TEST_F(RoundTripTransferTest, EmptyTuple) {
- RoundTripTest(*Literal::MakeTuple({}));
+ RoundTripTest(*LiteralUtil::MakeTuple({}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(),
- Literal::CreateR1<float>({3, 4}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR1<float>({}).get(),
- Literal::CreateR1<float>({3, 4}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
}
TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR0<float>(1.0).get(),
- Literal::CreateR1<int>({2, 3}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
+ LiteralUtil::CreateR1<int>({2, 3}).get()}));
}
// Below two tests are added to identify the cost of large data transfers.
TEST_F(RoundTripTransferTest, R2F32_Large) {
- RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+ RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
}
TEST_F(RoundTripTransferTest, R4F32_Large) {
Array4D<float> array4d(2, 2, 256, 256);
array4d.FillWithMultiples(1.0f);
- RoundTripTest(*Literal::CreateR4FromArray4D<float>(array4d));
+ RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
}
} // namespace