diff options
Diffstat (limited to 'tensorflow/compiler/tests/nullary_ops_test.py')
-rw-r--r-- | tensorflow/compiler/tests/nullary_ops_test.py | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index f985c5d2d9..38cb2f83ef 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase): output.run() def testConstants(self): - constants = [ - np.float32(42), - np.array([], dtype=np.float32), - np.array([1, 2], dtype=np.float32), - np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), - np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], - dtype=np.float32), - np.array([[[]], [[]]], dtype=np.float32), - np.array([[[[1]]]], dtype=np.float32), - ] - for c in constants: - self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + for dtype in self.numeric_types: + constants = [ + dtype(42), + np.array([], dtype=dtype), + np.array([1, 2], dtype=dtype), + np.array([7, 7, 7, 7, 7], dtype=dtype), + np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + + def testComplexConstants(self): + for dtype in self.complex_types: + constants = [ + dtype(42 + 3j), + np.array([], dtype=dtype), + np.ones([50], dtype=dtype) * (3 + 4j), + np.array([1j, 2 + 1j], dtype=dtype), + np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4 + 6j], [5, 6]], + [[10 + 7j, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1 + 3j]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) if __name__ == "__main__": |