aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal_util_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-23 14:42:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-23 14:54:43 -0700
commitf226eb3717a0df815579178f4393d4e68cbe08fc (patch)
treeecb70dd59c323a541dd9a11fb13018a73981f025 /tensorflow/compiler/xla/literal_util_test.cc
parent4f127e9019ff32f5c165550d535e4ad0fa587dd6 (diff)
[XLA] Adds a C64 type to XLA, with actual compilation support coming soon.
PiperOrigin-RevId: 173172916
Diffstat (limited to 'tensorflow/compiler/xla/literal_util_test.cc')
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc85
1 files changed, 83 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index e7dedd0821..a9af4849e2 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -107,6 +107,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto f16_lit = Literal::CreateR0<half>(static_cast<half>(0.5f));
ASSERT_EQ("0.5", f16_lit->ToString());
+
+ auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
+ ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
@@ -331,6 +334,19 @@ TEST_F(LiteralUtilTest, TupleEquality) {
EXPECT_NE(*tuple1, *different_tuple);
}
+TEST_F(LiteralUtilTest, C64Equality) {
+ // Test equality with tuples.
+ auto vector = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+
+ // Tuple with the same elements. One element is shared with the original
+ // tuple, the other is a clone of the element in the original tuple.
+ auto vector_clone = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+ EXPECT_EQ(*vector, *vector_clone);
+
+ auto vector_reversed = Literal::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
+ EXPECT_NE(*vector, *vector_reversed);
+}
+
TEST_F(LiteralUtilTest, IsAllTuple) {
auto element1 = Literal::CreateR0<float>(0.0);
auto element2 = Literal::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
@@ -381,6 +397,9 @@ TEST_F(LiteralUtilTest, IsAll) {
EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ complex64 c8_9 = {8, 9};
+ EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+
auto uint64_max = std::numeric_limits<uint64>::max();
EXPECT_FALSE(Literal::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
@@ -411,6 +430,25 @@ TEST_F(LiteralUtilTest, IsAllFloat) {
Literal::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
}
+TEST_F(LiteralUtilTest, IsAllComplex) {
+ // IsAllComplex always returns false when the literal is not complex.
+ EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllComplex(0));
+ EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllComplex(0));
+ EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllComplex(0));
+ EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllComplex(0));
+ EXPECT_FALSE(Literal::CreateR0<float>(0)->IsAllComplex(0));
+ EXPECT_FALSE(Literal::CreateR0<double>(0)->IsAllComplex(0));
+
+ complex64 c8_9 = {8, 9};
+ complex64 c7_9 = {7, 9};
+ EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})
+ ->IsAllComplex({8.0f, 9.0f}));
+ EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}})
+ ->IsAllComplex({8.0f, 9.0f}));
+ EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c7_9}})
+ ->IsAllComplex({8.0f, 9.0f}));
+}
+
TEST_F(LiteralUtilTest, IsZero) {
auto scalar_zero = Literal::CreateR0<float>(0.0f);
auto scalar_one = Literal::CreateR0<float>(1.0f);
@@ -422,12 +460,17 @@ TEST_F(LiteralUtilTest, IsZero) {
EXPECT_TRUE(array->IsZero({0, 2}));
EXPECT_TRUE(array->IsZero({1, 1}));
EXPECT_FALSE(array->IsZero({1, 2}));
+
+ auto complex_zero = Literal::CreateR0<complex64>(0.0f);
+ auto complex_nonzero = Literal::CreateR0<complex64>(0.5f);
+ EXPECT_TRUE(complex_zero->IsZero({}));
+ EXPECT_FALSE(complex_nonzero->IsZero({}));
}
template <typename T>
class LiteralUtilTestTemplated : public ::testing::Test {};
-using TestedTypes = ::testing::Types<float, int32, uint32>;
+using TestedTypes = ::testing::Types<float, int32, uint32, complex64>;
TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes);
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
@@ -626,13 +669,28 @@ TEST_F(LiteralUtilTest, PopulateR1S64) {
EXPECT_EQ(output, *expected);
}
-TEST_F(LiteralUtilTest, PopulateR2U64) {
+TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output;
output.PopulateR1<uint64>({{77, 88}});
auto expected = Literal::CreateR1<uint64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
+TEST_F(LiteralUtilTest, PopulateR1C64) {
+ Literal output;
+ output.PopulateR1<complex64>({{77, 88}});
+ auto expected = Literal::CreateR1<complex64>({{77, 88}});
+ EXPECT_EQ(output, *expected);
+}
+
+TEST_F(LiteralUtilTest, PopulateR2C64) {
+ Literal output;
+ output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
+ auto expected =
+ Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
+ EXPECT_EQ(output, *expected);
+}
+
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output;
output.PopulateWithValue<float>(2.5f, {});
@@ -654,6 +712,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
EXPECT_EQ(output, *expected);
}
+TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
+ Literal output;
+ output.PopulateWithValue<complex64>({4, 2}, {2, 2});
+ auto expected =
+ Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
+ EXPECT_EQ(output, *expected);
+}
+
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
Literal output;
half h(0.25f);
@@ -919,6 +985,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
}}, layout_r4_dim0major_);
+ auto c64 = Literal::CreateR4WithLayout<complex64>({{
+ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
+ {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
+ {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
+ }}, layout_r4_dim0major_);
// clang-format on
std::unique_ptr<Literal> conv;
@@ -961,12 +1032,22 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
conv = u32->Convert(F16).ConsumeValueOrDie();
EXPECT_EQ(*conv, *f16);
+ conv = s32->Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(*conv, *c64);
+
+ conv = f16->Convert(C64).ConsumeValueOrDie();
+ EXPECT_EQ(*conv, *c64);
+
EXPECT_EQ(s32->Convert(TUPLE).status().code(),
tensorflow::error::INVALID_ARGUMENT);
EXPECT_EQ(s32->Convert(S16).status().code(),
tensorflow::error::INVALID_ARGUMENT);
EXPECT_EQ(s32->Convert(U16).status().code(),
tensorflow::error::INVALID_ARGUMENT);
+ EXPECT_EQ(c64->Convert(F32).status().code(),
+ tensorflow::error::INVALID_ARGUMENT);
+ EXPECT_EQ(c64->Convert(S32).status().code(),
+ tensorflow::error::INVALID_ARGUMENT);
}
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {