aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-11 09:41:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-11 10:49:56 -0700
commitd3aa99075fdd9b5ce0d65537e12af4e3a448d18a (patch)
treec590535e54f5f8612e29ecdaacef419f0534ab74
parent5b79de86364db353e3fe68d1612f1465bad9fca0 (diff)
Fixes tf.cond with nested return values inside while loops.
Change: 152830112
-rw-r--r--tensorflow/python/ops/control_flow_ops.py3
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py16
2 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index dea2180069..93416140ff 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1416,8 +1416,7 @@ class ControlFlowContext(object):
def ExitResult(self, result):
"""Make a list of tensors available in the outer context."""
if self._outer_context:
- for x in result:
- self._outer_context.AddName(x.name)
+ nest.map_structure(lambda x: self._outer_context.AddName(x.name), result)
def GetWhileContext(self):
"""Return the while context containing this context."""
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index a88143224f..21a5afabe0 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -678,6 +678,22 @@ class DataTypesTest(TensorFlowTestCase):
[1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
[11, TestTuple(12, [13, 14]), np.ones([5, 5]), 16])
+ def test_cond_inside_while_loop(self):
+ def Body(i, matrix):
+ result_tuple, unused_matrix = control_flow_ops.cond(
+ constant_op.constant(True),
+ lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
+ lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
+ return [i+1, result_tuple.a]
+
+ iteration, matrix = control_flow_ops.while_loop(
+ lambda i, matrix: i < 10,
+ Body,
+ loop_vars=[constant_op.constant(0), array_ops.ones([2, 2])])
+
+ self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([]))
+ self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))
+
if __name__ == "__main__":
googletest.main()