diff options
author | Yuan Yu <yuanbyu@google.com> | 2016-05-31 13:53:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-05-31 15:03:10 -0700 |
commit | fd00c1ed27b781ae95878e662c2613bbdd840c67 (patch) | |
tree | 40eefe156b105c458fe4fce03167d956090f3691 | |
parent | 04ea051de751484988ba686993f9510d4742376c (diff) |
Make shape inference of loop variables a bit smarter.
Change: 123686334
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 13 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 15 |
2 files changed, 26 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 80cba28ead..74768d429d 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -603,6 +603,19 @@ class ControlFlowTest(tf.test.TestCase): r = r[1] * tf.ones([8, 8]) self.assertAllEqual(np.ones((8, 8)), r.eval()) + def testWhileShapeInference(self): + with self.test_session(): + i = tf.constant(0) + m = tf.ones([2, 2]) + c = lambda i, j: tf.less(i, 2) + def _b(i, j): + new_i = tf.add(i, 1) + new_j = tf.concat(0, [j, j]) + return [new_i, new_j] + r = tf.while_loop(c, _b, [i, m]) + self.assertTrue(r[1].get_shape()[0].value is None) + self.assertEqual(r[1].get_shape()[1], tf.Dimension(2)) + def _testNestedWhile_1(self, use_gpu): with self.test_session(use_gpu=use_gpu): n = tf.constant(0) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index e69b1cabd2..ba2d31b024 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -353,6 +353,18 @@ def _IsLoopExit(op): return op.type == "Exit" or op.type == "RefExit" +def _ShapeIntersection(shape1, shape2): + if shape1.dims is None or shape1.ndims != shape2.ndims: + return tensor_shape.unknown_shape() + rdims = [] + for dim1, dim2 in zip(shape1.dims, shape2.dims): + if dim1 == dim2: + rdims.append(dim1) + else: + rdims.append(tensor_shape.Dimension(None)) + return tensor_shape.TensorShape(rdims) + + class GradLoopState(object): """The state used for constructing the gradient graph for a while loop. @@ -1628,8 +1640,7 @@ class WhileContext(ControlFlowContext): self._loop_exits = exit_vars for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars): - if not m_var.get_shape() == n_var.get_shape(): - e_var._shape = tensor_shape.unknown_shape() + e_var._shape = _ShapeIntersection(m_var.get_shape(), n_var.get_shape()) # Exit the loop. self.ExitResult(exit_vars) |