diff options
author | 2017-01-31 17:06:14 -0800 | |
---|---|---|
committer | 2017-01-31 17:28:54 -0800 | |
commit | 88e1ca90f940a9ff45a6fa8fa8c5ff2c7e65cbc2 (patch) | |
tree | 160e64dd698e73266a0ced0b7121464a17c6153b | |
parent | 87a5793fffbe5ac884f19e608fc8e9b938764fbc (diff) |
Add string support to tf.zeros.
Change: 146185845
-rw-r--r-- | tensorflow/python/kernel_tests/constant_op_test.py | 32 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 7 |
2 files changed, 29 insertions, 10 deletions
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 935622d1f4..e502d58895 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -325,16 +325,20 @@ class ZerosTest(test.TestCase): dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64, - dtypes_lib.bool + dtypes_lib.bool, dtypes_lib.string ]: z = array_ops.zeros([2, 3], dtype=dtype) self.assertEqual(z.dtype, dtype) self.assertEqual([2, 3], z.get_shape()) - self.assertAllEqual(z.eval(), np.zeros([2, 3])) + z_value = z.eval() + self.assertFalse(np.any(z_value)) + self.assertEqual((2, 3), z_value.shape) z = array_ops.zeros(array_ops.shape(d), dtype=dtype) self.assertEqual(z.dtype, dtype) self.assertEqual([2, 3], z.get_shape()) - self.assertAllEqual(z.eval(), np.zeros([2, 3])) + z_value = z.eval() + self.assertFalse(np.any(z_value)) + self.assertEqual((2, 3), z_value.shape) class ZerosLikeTest(test.TestCase): @@ -342,30 +346,40 @@ class ZerosLikeTest(test.TestCase): def _compareZeros(self, dtype, use_gpu): with self.test_session(use_gpu=use_gpu): # Creates a tensor of non-zero values with shape 2 x 3. - numpy_dtype = dtype.as_numpy_dtype + # NOTE(kearnes): The default numpy dtype associated with tf.string is + # np.object (and can't be changed without breaking a lot things), which + # causes a TypeError in constant_op.constant below. Here we catch the + # special case of tf.string and set the numpy dtype appropriately. + if dtype == dtypes_lib.string: + numpy_dtype = np.string_ + else: + numpy_dtype = dtype.as_numpy_dtype d = constant_op.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype) # Constructs a tensor of zeros of the same dimensions and type as "d". z_var = array_ops.zeros_like(d) # Test that the type is correct self.assertEqual(z_var.dtype, dtype) - z_value = z_var.eval() + # Test that the shape is correct + self.assertEqual([2, 3], z_var.get_shape()) # Test that the value is correct - self.assertTrue(np.array_equal(z_value, np.array([[0] * 3] * 2))) - self.assertEqual([2, 3], z_var.get_shape()) + z_value = z_var.eval() + self.assertFalse(np.any(z_value)) + self.assertEqual((2, 3), z_value.shape) def testZerosLikeCPU(self): for dtype in [ dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, - dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64 + dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64, + dtypes_lib.string ]: self._compareZeros(dtype, False) def testZerosLikeGPU(self): for dtype in [ dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, - dtypes_lib.bool, dtypes_lib.int64 + dtypes_lib.bool, dtypes_lib.int64, dtypes_lib.string ]: self._compareZeros(dtype, True) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index fc47fc325f..d0678830a4 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1370,7 +1370,12 @@ def zeros(shape, dtype=dtypes.float32, name=None): """ dtype = dtypes.as_dtype(dtype).base_dtype with ops.name_scope(name, "zeros", [shape]) as name: - zero = False if dtype == dtypes.bool else 0 + if dtype == dtypes.bool: + zero = False + elif dtype == dtypes.string: + zero = "" + else: + zero = 0 try: shape = tensor_shape.as_shape(shape) output = constant(zero, shape=shape, dtype=dtype, name=name) |