diff options
author | Derek Murray <mrry@google.com> | 2016-01-11 11:37:01 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2016-01-11 11:37:01 -0800 |
commit | a6037e933fe8799c80f94decc9f5b178b833c1b8 (patch) | |
tree | 25b51eb52664cc110fddf27b3f41af837859552e | |
parent | 60bccf654abc869b19b2ec58d021421ef9480445 (diff) |
Add better shape inference for `tf.zeros_like()` and `tf.ones_like()`.
Previously, partial shape information was discarded, because our
constant evaluation for (e.g.) `tf.shape(tf.placeholder([..., None,
...]))` could not produce a Numpy array for the shape. Since the
*_like wrappers have access to the input tensor, we can use
`Tensor.set_shape()` to add back the partial information.
Fixes #744.
Change: 111856452
-rw-r--r-- | tensorflow/python/kernel_tests/constant_op_test.py | 10 | ||||
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 8 |
2 files changed, 16 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 253244b412..a61c73e295 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -321,6 +321,11 @@ class ZerosLikeTest(tf.test.TestCase): self.assertTrue(np.array_equal(z_value, np.array([[0] * 3] * 2))) self.assertEqual([2, 3], z_var.get_shape()) + def testZerosLikePartialShape(self): + d = tf.placeholder(tf.float32, shape=[None, 4, None]) + z = tf.zeros_like(d) + self.assertEqual(d.get_shape().as_list(), z.get_shape().as_list()) + def testGenZerosLike(self): for dtype in [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, @@ -406,6 +411,11 @@ class OnesLikeTest(tf.test.TestCase): self.assertTrue(np.array_equal(z_value, np.array([[1] * 3] * 2))) self.assertEqual([2, 3], z_var.get_shape()) + def testOnesLikePartialShape(self): + d = tf.placeholder(tf.float32, shape=[None, 4, None]) + z = tf.zeros_like(d) + self.assertEqual(d.get_shape().as_list(), z.get_shape().as_list()) + def testGenOnesLike(self): for dtype in [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 0f36ed7e41..53c317183c 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -563,7 +563,9 @@ def zeros_like(tensor, dtype=None, name=None): zeros_shape = shape(tensor) if dtype is None: dtype = tensor.dtype - return zeros(zeros_shape, dtype=dtype, name=name) + ret = zeros(zeros_shape, dtype=dtype, name=name) + ret.set_shape(tensor.get_shape()) + return ret def ones_like(tensor, dtype=None, name=None): @@ -594,7 +596,9 @@ def ones_like(tensor, dtype=None, name=None): ones_shape = shape(tensor) if dtype is None: dtype = tensor.dtype - return ones(ones_shape, dtype=dtype, name=name) + ret = ones(ones_shape, dtype=dtype, name=name) + ret.set_shape(tensor.get_shape()) + return ret def zeros_initializer(shape, dtype=dtypes.float32): |