diff options
author | Derek Murray <mrry@google.com> | 2016-06-25 14:45:27 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-06-25 16:02:34 -0700 |
commit | 706a5baa6e633ffbbcdf49f69e3ef88421001a76 (patch) | |
tree | 800be06b9e1fb7ac544c0e9c2cf1cb12f4084693 /tensorflow/python/kernel_tests/reshape_op_test.py | |
parent | 29a53ccbf6c3b5d94c8c8370e3163312e1df47bd (diff) |
Add partial shape inference for values that are used as shapes.
With this change, the shape inference for `tf.reshape()` will
correctly observe that, for example:
```python
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.int32)
z = tf.reshape(x, [y, 37])
print(z.get_shape()) # ==> (?, 37)
```
Partially addresses #2938.
Change: 125875146
Diffstat (limited to 'tensorflow/python/kernel_tests/reshape_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/reshape_op_test.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py index 0487621b46..a68f722244 100644 --- a/tensorflow/python/kernel_tests/reshape_op_test.py +++ b/tensorflow/python/kernel_tests/reshape_op_test.py @@ -99,11 +99,6 @@ class ReshapeTest(tf.test.TestCase): self._testBothReshape(x, [1, -1, 5]) def testErrors(self): - x = tf.constant(0.0, shape=[1, 0, 3]) - with self.assertRaisesRegexp( - ValueError, "cannot infer the missing input size"): - tf.reshape(x, [0, -1, 5]) - y = tf.constant(0.0, shape=[23, 29, 31]) with self.assertRaisesRegexp(ValueError, "isn't divisible by 17"): tf.reshape(y, [17, -1]) @@ -128,6 +123,20 @@ class ReshapeTest(tf.test.TestCase): y = tf.reshape(x, tf.placeholder(tf.int32, shape=(3,))) self.assertEqual([None, None, None], y.get_shape().as_list()) + # Unknown input shape, partial new shape using `tf.pack()`. + y = tf.reshape(x, [tf.placeholder(tf.int32), 37]) + self.assertEqual([None, 37], y.get_shape().as_list()) + + # Unknown input shape, partial new shape using `tf.concat()`. + y = tf.reshape(x, tf.concat(0, [tf.placeholder(tf.int32, shape=(2,)), + [37, 42]])) + self.assertEqual([None, None, 37, 42], y.get_shape().as_list()) + + # Unknown input shape, partial new shape using `tf.shape()`. + y = tf.reshape(x, tf.shape(tf.placeholder(tf.float32, + shape=[None, 37, None]))) + self.assertEqual([None, 37, None], y.get_shape().as_list()) + if __name__ == "__main__": tf.test.main() |