aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor_test.cc
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2016-03-15 17:53:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-16 16:19:29 -0700
commit3ae663ccc5d08976e0f547d5b2ece35067a6673e (patch)
tree45810d02a11acc9b45e4268fd00548929ea91ffd /tensorflow/core/framework/tensor_test.cc
parenta0d21ec39cc3f18781d2d37798aa328e12f92844 (diff)
Merge changes from github.
Change: 117301677
Diffstat (limited to 'tensorflow/core/framework/tensor_test.cc')
-rw-r--r--tensorflow/core/framework/tensor_test.cc72
1 files changed, 66 insertions, 6 deletions
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index ec0fb57aad..13896f9177 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -47,12 +47,17 @@ TEST(TensorTest, DataType_Traits) {
// Unfortunately. std::complex::complex() initializes (0, 0).
EXPECT_FALSE(std::is_trivial<complex64>::value);
- EXPECT_FALSE(std::is_trivial<std::complex<double>>::value);
+ EXPECT_FALSE(std::is_trivial<complex128>::value);
EXPECT_TRUE(std::is_trivial<float[2]>::value);
- struct MyComplex {
+ EXPECT_TRUE(std::is_trivial<double[2]>::value);
+ struct MyComplex64 {
float re, im;
};
- EXPECT_TRUE(std::is_trivial<MyComplex>::value);
+ EXPECT_TRUE(std::is_trivial<MyComplex64>::value);
+ struct MyComplex128 {
+ double re, im;
+ };
+ EXPECT_TRUE(std::is_trivial<MyComplex128>::value);
}
template <typename T>
@@ -420,13 +425,19 @@ TEST(Tensor_Bool, SimpleWithHelper) {
test::ExpectTensorEqual<bool>(t1, t2);
}
-TEST(Tensor_Complex, Simple) {
+TEST(Tensor_Complex, Simple64) {
Tensor t(DT_COMPLEX64, {4, 5, 3, 7});
t.flat<complex64>().setRandom();
TestCopies<complex64>(t);
}
-TEST(Tensor_Complex, SimpleWithHelper) {
+TEST(Tensor_Complex, Simple128) {
+ Tensor t(DT_COMPLEX128, {4, 5, 3, 7});
+ t.flat<complex128>().setRandom();
+ TestCopies<complex128>(t);
+}
+
+TEST(Tensor_Complex, SimpleWithHelper64) {
{
Tensor t1 = test::AsTensor<complex64>({0,
{1, 1},
@@ -444,7 +455,7 @@ TEST(Tensor_Complex, SimpleWithHelper) {
test::ExpectTensorEqual<complex64>(t2, t3);
}
- // Does some numeric operations for complex numbers.
+ // Does some numeric operations for complex64 numbers.
{
const float PI = std::acos(-1);
const complex64 rotate_45 = std::polar(1.0f, PI / 4);
@@ -475,6 +486,55 @@ TEST(Tensor_Complex, SimpleWithHelper) {
}
}
+TEST(Tensor_Complex, SimpleWithHelper128) {
+ {
+ Tensor t1 = test::AsTensor<complex128>({0,
+ {1, 1},
+ complex128(2),
+ complex128(3, 3),
+ complex128(0, 4),
+ complex128(2, 5)},
+ {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<complex128>() = t1.flat<complex128>() * complex128(0, 2);
+ Tensor t3 = test::AsTensor<complex128>(
+ {0, {-2, 2}, {0, 4}, {-6, 6}, {-8, 0}, {-10, 4}},
+ // shape
+ {2, 3});
+ test::ExpectTensorEqual<complex128>(t2, t3);
+ }
+
+ // Does some numeric operations for complex128 numbers.
+ {
+ const double PI = std::acos(-1);
+ const complex128 rotate_45 = std::polar(1.0, PI / 4);
+
+ // x contains all the 8-th root of unity.
+ Tensor x(DT_COMPLEX128, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ x.vec<complex128>()(i) = std::pow(rotate_45, i);
+ }
+
+ // Shift the roots by 45 degree.
+ Tensor y(DT_COMPLEX128, TensorShape({8}));
+ y.vec<complex128>() = x.vec<complex128>() * rotate_45;
+ Tensor y_expected(DT_COMPLEX128, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ y_expected.vec<complex128>()(i) = std::pow(rotate_45, i + 1);
+ }
+ test::ExpectTensorNear<complex128>(y, y_expected, 1e-5);
+
+ // Raise roots to the power of 8.
+ Tensor z(DT_COMPLEX128, TensorShape({8}));
+ z.vec<complex128>() = x.vec<complex128>().pow(8);
+ Tensor z_expected(DT_COMPLEX128, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ z_expected.vec<complex128>()(i) = 1;
+ }
+ test::ExpectTensorNear<complex128>(z, z_expected, 1e-5);
+ }
+}
+
// On the alignment.
//
// As of 2015/8, tensorflow::Tensor allocates its buffer with 32-byte