author	Yuan Yu <yuanbyu@google.com>	2016-05-31 13:53:50 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2016-05-31 15:03:10 -0700
commit	fd00c1ed27b781ae95878e662c2613bbdd840c67 (patch)
tree	40eefe156b105c458fe4fce03167d956090f3691 /tensorflow/python
parent	04ea051de751484988ba686993f9510d4742376c (diff)
Make shape inference of loop variables a bit smarter.
Change: 123686334
Diffstat (limited to 'tensorflow/python')
-rw-r--r--	tensorflow/python/kernel_tests/control_flow_ops_py_test.py	13
-rw-r--r--	tensorflow/python/ops/control_flow_ops.py	15
2 files changed, 26 insertions(+), 2 deletions(-)
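What the change does, in short: before this patch, if a loop variable's shape at the Merge node differed at all from its shape after one iteration (the NextIteration value), the Exit variable's static shape was reset to fully unknown. The patch instead intersects the two shapes dimension by dimension, so dimensions the loop body does not change stay statically known. Below is a minimal sketch of that intersection at the TensorShape level; the helper mirrors the patch's _ShapeIntersection, and the name and example shapes are illustrative, not part of the patch.

from tensorflow.python.framework import tensor_shape

def shape_intersection(shape1, shape2):
  # Unknown rank or mismatched ranks: fall back to a fully unknown shape.
  if shape1.dims is None or shape1.ndims != shape2.ndims:
    return tensor_shape.unknown_shape()
  # Keep each dimension only if both shapes agree on it; otherwise mark it unknown.
  rdims = [dim1 if dim1 == dim2 else tensor_shape.Dimension(None)
           for dim1, dim2 in zip(shape1.dims, shape2.dims)]
  return tensor_shape.TensorShape(rdims)

# A loop variable that starts as 2 x 2 and grows along dimension 0:
merge_shape = tensor_shape.TensorShape([2, 2])  # value entering the loop
next_shape = tensor_shape.TensorShape([4, 2])   # value after one iteration
print(shape_intersection(merge_shape, next_shape))  # (?, 2) rather than <unknown>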
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 80cba28ead..74768d429d 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -603,6 +603,19 @@ class ControlFlowTest(tf.test.TestCase):
      r = r[1] * tf.ones([8, 8])
      self.assertAllEqual(np.ones((8, 8)), r.eval())

+  def testWhileShapeInference(self):
+    with self.test_session():
+      i = tf.constant(0)
+      m = tf.ones([2, 2])
+      c = lambda i, j: tf.less(i, 2)
+      def _b(i, j):
+        new_i = tf.add(i, 1)
+        new_j = tf.concat(0, [j, j])
+        return [new_i, new_j]
+      r = tf.while_loop(c, _b, [i, m])
+      self.assertTrue(r[1].get_shape()[0].value is None)
+      self.assertEqual(r[1].get_shape()[1], tf.Dimension(2))
+
  def _testNestedWhile_1(self, use_gpu):
    with self.test_session(use_gpu=use_gpu):
      n = tf.constant(0)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index e69b1cabd2..ba2d31b024 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -353,6 +353,18 @@ def _IsLoopExit(op):
  return op.type == "Exit" or op.type == "RefExit"


+def _ShapeIntersection(shape1, shape2):
+  if shape1.dims is None or shape1.ndims != shape2.ndims:
+    return tensor_shape.unknown_shape()
+  rdims = []
+  for dim1, dim2 in zip(shape1.dims, shape2.dims):
+    if dim1 == dim2:
+      rdims.append(dim1)
+    else:
+      rdims.append(tensor_shape.Dimension(None))
+  return tensor_shape.TensorShape(rdims)
+
+
class GradLoopState(object):
  """The state used for constructing the gradient graph for a while loop.
@@ -1628,8 +1640,7 @@ class WhileContext(ControlFlowContext):
    self._loop_exits = exit_vars

    for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars):
-      if not m_var.get_shape() == n_var.get_shape():
-        e_var._shape = tensor_shape.unknown_shape()
+      e_var._shape = _ShapeIntersection(m_var.get_shape(), n_var.get_shape())

    # Exit the loop.
    self.ExitResult(exit_vars)
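In terms of user-visible behavior, exit values of tf.while_loop now keep partially known shapes. A rough usage sketch based on the new test above (using the TF API of that era, where tf.concat takes the axis as its first argument):

import tensorflow as tf

i = tf.constant(0)
m = tf.ones([2, 2])

# The body doubles m along dimension 0 every iteration, so dimension 0 cannot be
# inferred statically, but dimension 1 stays 2 for the whole loop.
r = tf.while_loop(lambda i, m: tf.less(i, 2),
                  lambda i, m: [tf.add(i, 1), tf.concat(0, [m, m])],
                  [i, m])

print(r[1].get_shape())  # before this change: <unknown>; after: (?, 2)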