diff options
author | 2017-10-23 14:42:57 -0700 | |
---|---|---|
committer | 2017-10-23 14:54:43 -0700 | |
commit | f226eb3717a0df815579178f4393d4e68cbe08fc (patch) | |
tree | ecb70dd59c323a541dd9a11fb13018a73981f025 /tensorflow/compiler/xla/literal_util_test.cc | |
parent | 4f127e9019ff32f5c165550d535e4ad0fa587dd6 (diff) |
[XLA] Adds a C64 type to XLA, with actual compilation support coming soon.
PiperOrigin-RevId: 173172916
Diffstat (limited to 'tensorflow/compiler/xla/literal_util_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/literal_util_test.cc | 85 |
1 files changed, 83 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index e7dedd0821..a9af4849e2 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -107,6 +107,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f16_lit = Literal::CreateR0<half>(static_cast<half>(0.5f)); ASSERT_EQ("0.5", f16_lit->ToString()); + + auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f}); + ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -331,6 +334,19 @@ TEST_F(LiteralUtilTest, TupleEquality) { EXPECT_NE(*tuple1, *different_tuple); } +TEST_F(LiteralUtilTest, C64Equality) { + // Test equality with tuples. + auto vector = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. + auto vector_clone = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}}); + EXPECT_EQ(*vector, *vector_clone); + + auto vector_reversed = Literal::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}}); + EXPECT_NE(*vector, *vector_reversed); +} + TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = Literal::CreateR0<float>(0.0); auto element2 = Literal::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}}); @@ -381,6 +397,9 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8)); + complex64 c8_9 = {8, 9}; + EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8)); + auto uint64_max = std::numeric_limits<uint64>::max(); EXPECT_FALSE(Literal::CreateR2<uint64>( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) @@ -411,6 +430,25 @@ TEST_F(LiteralUtilTest, IsAllFloat) { Literal::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } +TEST_F(LiteralUtilTest, IsAllComplex) { + // IsAllComplex always returns false when the literal is not complex. + EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0<float>(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0<double>(0)->IsAllComplex(0)); + + complex64 c8_9 = {8, 9}; + complex64 c7_9 = {7, 9}; + EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}}) + ->IsAllComplex({8.0f, 9.0f})); + EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}}) + ->IsAllComplex({8.0f, 9.0f})); + EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c7_9}}) + ->IsAllComplex({8.0f, 9.0f})); +} + TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = Literal::CreateR0<float>(0.0f); auto scalar_one = Literal::CreateR0<float>(1.0f); @@ -422,12 +460,17 @@ TEST_F(LiteralUtilTest, IsZero) { EXPECT_TRUE(array->IsZero({0, 2})); EXPECT_TRUE(array->IsZero({1, 1})); EXPECT_FALSE(array->IsZero({1, 2})); + + auto complex_zero = Literal::CreateR0<complex64>(0.0f); + auto complex_nonzero = Literal::CreateR0<complex64>(0.5f); + EXPECT_TRUE(complex_zero->IsZero({})); + EXPECT_FALSE(complex_nonzero->IsZero({})); } template <typename T> class LiteralUtilTestTemplated : public ::testing::Test {}; -using TestedTypes = ::testing::Types<float, int32, uint32>; +using TestedTypes = ::testing::Types<float, int32, uint32, complex64>; TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { @@ -626,13 +669,28 @@ TEST_F(LiteralUtilTest, PopulateR1S64) { EXPECT_EQ(output, *expected); } -TEST_F(LiteralUtilTest, PopulateR2U64) { +TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output; output.PopulateR1<uint64>({{77, 88}}); auto expected = Literal::CreateR1<uint64>({{77, 88}}); EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateR1C64) { + Literal output; + output.PopulateR1<complex64>({{77, 88}}); + auto expected = Literal::CreateR1<complex64>({{77, 88}}); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateR2C64) { + Literal output; + output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + auto expected = + Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue<float>(2.5f, {}); @@ -654,6 +712,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { + Literal output; + output.PopulateWithValue<complex64>({4, 2}, {2, 2}); + auto expected = + Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output; half h(0.25f); @@ -919,6 +985,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, }}, layout_r4_dim0major_); + auto c64 = Literal::CreateR4WithLayout<complex64>({{ + {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, + {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, + {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, + }}, layout_r4_dim0major_); // clang-format on std::unique_ptr<Literal> conv; @@ -961,12 +1032,22 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = u32->Convert(F16).ConsumeValueOrDie(); EXPECT_EQ(*conv, *f16); + conv = s32->Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *c64); + + conv = f16->Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *c64); + EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); EXPECT_EQ(s32->Convert(S16).status().code(), tensorflow::error::INVALID_ARGUMENT); EXPECT_EQ(s32->Convert(U16).status().code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(c64->Convert(F32).status().code(), + tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(c64->Convert(S32).status().code(), + tensorflow::error::INVALID_ARGUMENT); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { |