aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-31 17:06:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-31 17:28:54 -0800
commit88e1ca90f940a9ff45a6fa8fa8c5ff2c7e65cbc2 (patch)
tree160e64dd698e73266a0ced0b7121464a17c6153b
parent87a5793fffbe5ac884f19e608fc8e9b938764fbc (diff)
Add string support to tf.zeros.
Change: 146185845
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py32
-rw-r--r--tensorflow/python/ops/array_ops.py7
2 files changed, 29 insertions, 10 deletions
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 935622d1f4..e502d58895 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -325,16 +325,20 @@ class ZerosTest(test.TestCase):
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64,
- dtypes_lib.bool
+ dtypes_lib.bool, dtypes_lib.string
]:
z = array_ops.zeros([2, 3], dtype=dtype)
self.assertEqual(z.dtype, dtype)
self.assertEqual([2, 3], z.get_shape())
- self.assertAllEqual(z.eval(), np.zeros([2, 3]))
+ z_value = z.eval()
+ self.assertFalse(np.any(z_value))
+ self.assertEqual((2, 3), z_value.shape)
z = array_ops.zeros(array_ops.shape(d), dtype=dtype)
self.assertEqual(z.dtype, dtype)
self.assertEqual([2, 3], z.get_shape())
- self.assertAllEqual(z.eval(), np.zeros([2, 3]))
+ z_value = z.eval()
+ self.assertFalse(np.any(z_value))
+ self.assertEqual((2, 3), z_value.shape)
class ZerosLikeTest(test.TestCase):
@@ -342,30 +346,40 @@ class ZerosLikeTest(test.TestCase):
def _compareZeros(self, dtype, use_gpu):
with self.test_session(use_gpu=use_gpu):
# Creates a tensor of non-zero values with shape 2 x 3.
- numpy_dtype = dtype.as_numpy_dtype
+ # NOTE(kearnes): The default numpy dtype associated with tf.string is
+ # np.object (and can't be changed without breaking a lot things), which
+ # causes a TypeError in constant_op.constant below. Here we catch the
+ # special case of tf.string and set the numpy dtype appropriately.
+ if dtype == dtypes_lib.string:
+ numpy_dtype = np.string_
+ else:
+ numpy_dtype = dtype.as_numpy_dtype
d = constant_op.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
# Constructs a tensor of zeros of the same dimensions and type as "d".
z_var = array_ops.zeros_like(d)
# Test that the type is correct
self.assertEqual(z_var.dtype, dtype)
- z_value = z_var.eval()
+ # Test that the shape is correct
+ self.assertEqual([2, 3], z_var.get_shape())
# Test that the value is correct
- self.assertTrue(np.array_equal(z_value, np.array([[0] * 3] * 2)))
- self.assertEqual([2, 3], z_var.get_shape())
+ z_value = z_var.eval()
+ self.assertFalse(np.any(z_value))
+ self.assertEqual((2, 3), z_value.shape)
def testZerosLikeCPU(self):
for dtype in [
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
- dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64
+ dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64,
+ dtypes_lib.string
]:
self._compareZeros(dtype, False)
def testZerosLikeGPU(self):
for dtype in [
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
- dtypes_lib.bool, dtypes_lib.int64
+ dtypes_lib.bool, dtypes_lib.int64, dtypes_lib.string
]:
self._compareZeros(dtype, True)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index fc47fc325f..d0678830a4 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1370,7 +1370,12 @@ def zeros(shape, dtype=dtypes.float32, name=None):
"""
dtype = dtypes.as_dtype(dtype).base_dtype
with ops.name_scope(name, "zeros", [shape]) as name:
- zero = False if dtype == dtypes.bool else 0
+ if dtype == dtypes.bool:
+ zero = False
+ elif dtype == dtypes.string:
+ zero = ""
+ else:
+ zero = 0
try:
shape = tensor_shape.as_shape(shape)
output = constant(zero, shape=shape, dtype=dtype, name=name)