aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/reshape_op_test.py
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-06-25 14:45:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-25 16:02:34 -0700
commit706a5baa6e633ffbbcdf49f69e3ef88421001a76 (patch)
tree800be06b9e1fb7ac544c0e9c2cf1cb12f4084693 /tensorflow/python/kernel_tests/reshape_op_test.py
parent29a53ccbf6c3b5d94c8c8370e3163312e1df47bd (diff)
Add partial shape inference for values that are used as shapes.
With this change, the shape inference for `tf.reshape()` will correctly observe that, for example: ```python x = tf.placeholder(tf.float32) y = tf.placeholder(tf.int32) z = tf.reshape(x, [y, 37]) print(z.get_shape()) # ==> (?, 37) ``` Partially addresses #2938. Change: 125875146
Diffstat (limited to 'tensorflow/python/kernel_tests/reshape_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/reshape_op_test.py19
1 files changed, 14 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index 0487621b46..a68f722244 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -99,11 +99,6 @@ class ReshapeTest(tf.test.TestCase):
self._testBothReshape(x, [1, -1, 5])
def testErrors(self):
- x = tf.constant(0.0, shape=[1, 0, 3])
- with self.assertRaisesRegexp(
- ValueError, "cannot infer the missing input size"):
- tf.reshape(x, [0, -1, 5])
-
y = tf.constant(0.0, shape=[23, 29, 31])
with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"):
tf.reshape(y, [17, -1])
@@ -128,6 +123,20 @@ class ReshapeTest(tf.test.TestCase):
y = tf.reshape(x, tf.placeholder(tf.int32, shape=(3,)))
self.assertEqual([None, None, None], y.get_shape().as_list())
+ # Unknown input shape, partial new shape using `tf.pack()`.
+ y = tf.reshape(x, [tf.placeholder(tf.int32), 37])
+ self.assertEqual([None, 37], y.get_shape().as_list())
+
+ # Unknown input shape, partial new shape using `tf.concat()`.
+ y = tf.reshape(x, tf.concat(0, [tf.placeholder(tf.int32, shape=(2,)),
+ [37, 42]]))
+ self.assertEqual([None, None, 37, 42], y.get_shape().as_list())
+
+ # Unknown input shape, partial new shape using `tf.shape()`.
+ y = tf.reshape(x, tf.shape(tf.placeholder(tf.float32,
+ shape=[None, 37, None])))
+ self.assertEqual([None, 37, None], y.get_shape().as_list())
+
if __name__ == "__main__":
tf.test.main()