aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/nullary_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/nullary_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/nullary_ops_test.py43
1 files changed, 31 insertions, 12 deletions
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
index f985c5d2d9..38cb2f83ef 100644
--- a/tensorflow/compiler/tests/nullary_ops_test.py
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase):
output.run()
def testConstants(self):
- constants = [
- np.float32(42),
- np.array([], dtype=np.float32),
- np.array([1, 2], dtype=np.float32),
- np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
- np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
- dtype=np.float32),
- np.array([[[]], [[]]], dtype=np.float32),
- np.array([[[[1]]]], dtype=np.float32),
- ]
- for c in constants:
- self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+ for dtype in self.numeric_types:
+ constants = [
+ dtype(42),
+ np.array([], dtype=dtype),
+ np.array([1, 2], dtype=dtype),
+ np.array([7, 7, 7, 7, 7], dtype=dtype),
+ np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
+ np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
+ dtype=dtype),
+ np.array([[[]], [[]]], dtype=dtype),
+ np.array([[[[1]]]], dtype=dtype),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+
+ def testComplexConstants(self):
+ for dtype in self.complex_types:
+ constants = [
+ dtype(42 + 3j),
+ np.array([], dtype=dtype),
+ np.ones([50], dtype=dtype) * (3 + 4j),
+ np.array([1j, 2 + 1j], dtype=dtype),
+ np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype),
+ np.array([[[1, 2], [3, 4 + 6j], [5, 6]],
+ [[10 + 7j, 20], [30, 40], [50, 60]]],
+ dtype=dtype),
+ np.array([[[]], [[]]], dtype=dtype),
+ np.array([[[[1 + 3j]]]], dtype=dtype),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
if __name__ == "__main__":