aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-01-11 11:37:01 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-11 11:37:01 -0800
commita6037e933fe8799c80f94decc9f5b178b833c1b8 (patch)
tree25b51eb52664cc110fddf27b3f41af837859552e
parent60bccf654abc869b19b2ec58d021421ef9480445 (diff)
Add better shape inference for `tf.zeros_like()` and `tf.ones_like()`.
Previously, partial shape information was discarded, because our constant evaluation for (e.g.) `tf.shape(tf.placeholder([..., None, ...]))` could not produce a Numpy array for the shape. Since the *_like wrappers have access to the input tensor, we can use `Tensor.set_shape()` to add back the partial information. Fixes #744. Change: 111856452
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py10
-rw-r--r--tensorflow/python/ops/array_ops.py8
2 files changed, 16 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 253244b412..a61c73e295 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -321,6 +321,11 @@ class ZerosLikeTest(tf.test.TestCase):
self.assertTrue(np.array_equal(z_value, np.array([[0] * 3] * 2)))
self.assertEqual([2, 3], z_var.get_shape())
+ def testZerosLikePartialShape(self):
+ d = tf.placeholder(tf.float32, shape=[None, 4, None])
+ z = tf.zeros_like(d)
+ self.assertEqual(d.get_shape().as_list(), z.get_shape().as_list())
+
def testGenZerosLike(self):
for dtype in [tf.float32, tf.float64, tf.int32,
tf.uint8, tf.int16, tf.int8,
@@ -406,6 +411,11 @@ class OnesLikeTest(tf.test.TestCase):
self.assertTrue(np.array_equal(z_value, np.array([[1] * 3] * 2)))
self.assertEqual([2, 3], z_var.get_shape())
+ def testOnesLikePartialShape(self):
+ d = tf.placeholder(tf.float32, shape=[None, 4, None])
+ z = tf.zeros_like(d)
+ self.assertEqual(d.get_shape().as_list(), z.get_shape().as_list())
+
def testGenOnesLike(self):
for dtype in [tf.float32, tf.float64, tf.int32,
tf.uint8, tf.int16, tf.int8,
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 0f36ed7e41..53c317183c 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -563,7 +563,9 @@ def zeros_like(tensor, dtype=None, name=None):
zeros_shape = shape(tensor)
if dtype is None:
dtype = tensor.dtype
- return zeros(zeros_shape, dtype=dtype, name=name)
+ ret = zeros(zeros_shape, dtype=dtype, name=name)
+ ret.set_shape(tensor.get_shape())
+ return ret
def ones_like(tensor, dtype=None, name=None):
@@ -594,7 +596,9 @@ def ones_like(tensor, dtype=None, name=None):
ones_shape = shape(tensor)
if dtype is None:
dtype = tensor.dtype
- return ones(ones_shape, dtype=dtype, name=name)
+ ret = ones(ones_shape, dtype=dtype, name=name)
+ ret.set_shape(tensor.get_shape())
+ return ret
def zeros_initializer(shape, dtype=dtypes.float32):