diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/constant_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/constant_op_test.py | 31 |
1 files changed, 22 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 36fc93f107..cc769ec274 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -28,7 +28,7 @@ class ConstantTest(tf.test.TestCase): np_ans = np.array(x) with self.test_session(use_gpu=False): tf_ans = tf.convert_to_tensor(x).eval() - if np_ans.dtype in [np.float32, np.float64, np.complex64]: + if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]: self.assertAllClose(np_ans, tf_ans) else: self.assertAllEqual(np_ans, tf_ans) @@ -37,7 +37,7 @@ class ConstantTest(tf.test.TestCase): np_ans = np.array(x) with self.test_session(use_gpu=True): tf_ans = tf.convert_to_tensor(x).eval() - if np_ans.dtype in [np.float32, np.float64, np.complex64]: + if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]: self.assertAllClose(np_ans, tf_ans) else: self.assertAllEqual(np_ans, tf_ans) @@ -70,7 +70,7 @@ class ConstantTest(tf.test.TestCase): (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int64)) self._testAll(np.empty((2, 0, 5)).astype(np.int64)) - def testSComplex(self): + def testComplex64(self): self._testAll( np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5]).astype( np.complex64)) @@ -79,6 +79,15 @@ class ConstantTest(tf.test.TestCase): np.complex64)) self._testAll(np.empty((2, 0, 5)).astype(np.complex64)) + def testComplex128(self): + self._testAll( + np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5]).astype( + np.complex128)) + self._testAll(np.complex( + 1, 2) * np.random.normal(size=30).reshape([2, 3, 5]).astype( + np.complex128)) + self._testAll(np.empty((2, 0, 5)).astype(np.complex128)) + def testString(self): self._testCpu(np.array([tf.compat.as_bytes(str(x)) for x in np.arange(-15, 15)]).reshape([2, 3, 5])) @@ -295,7 +304,7 @@ class ZerosTest(tf.test.TestCase): # Test explicit type control for dtype in [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, - tf.complex64, tf.int64]: + tf.complex64, tf.complex128, tf.int64]: z = tf.zeros([2, 3], dtype=dtype) self.assertEqual(z.dtype, dtype) self.assertEqual([2, 3], z.get_shape()) @@ -323,7 +332,7 @@ class ZerosLikeTest(tf.test.TestCase): def testZerosLikeCPU(self): for dtype in [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, - tf.complex64, tf.int64]: + tf.complex64, tf.complex128, tf.int64]: self._compareZeros(dtype, False) def testZerosLikeGPU(self): @@ -399,9 +408,9 @@ class OnesTest(tf.test.TestCase): self.assertEqual(z.dtype, tf.float32) self.assertEqual([2, 3], z.get_shape()) # Test explicit type control - for dtype in [tf.float32, tf.float64, tf.int32, + for dtype in (tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, - tf.complex64, tf.int64]: + tf.complex64, tf.complex128, tf.int64): z = tf.ones([2, 3], dtype=dtype) self.assertEqual(z.dtype, dtype) self.assertEqual([2, 3], z.get_shape()) @@ -415,7 +424,7 @@ class OnesLikeTest(tf.test.TestCase): def testOnesLike(self): for dtype in [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, - tf.complex64, tf.int64]: + tf.complex64, tf.complex128, tf.int64]: numpy_dtype = dtype.as_numpy_dtype with self.test_session(): # Creates a tensor of non-zero values with shape 2 x 3. @@ -466,10 +475,14 @@ class FillTest(tf.test.TestCase): np_ans = np.array([[-42] * 3] * 2).astype(np.int64) self._compareAll([2, 3], np_ans[0][0], np_ans) - def testFillComplex(self): + def testFillComplex64(self): np_ans = np.array([[0.15] * 3] * 2).astype(np.complex64) self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False) + def testFillComplex128(self): + np_ans = np.array([[0.15] * 3] * 2).astype(np.complex128) + self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False) + def testFillString(self): np_ans = np.array([[b"yolo"] * 3] * 2) with self.test_session(use_gpu=False): |