diff options
author | 2017-04-11 09:41:14 -0800 | |
---|---|---|
committer | 2017-04-11 10:49:56 -0700 | |
commit | d3aa99075fdd9b5ce0d65537e12af4e3a448d18a (patch) | |
tree | c590535e54f5f8612e29ecdaacef419f0534ab74 | |
parent | 5b79de86364db353e3fe68d1612f1465bad9fca0 (diff) |
Fixes tf.cond with nested return values inside while loops.
Change: 152830112
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 3 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops_test.py | 16 |
2 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index dea2180069..93416140ff 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1416,8 +1416,7 @@ class ControlFlowContext(object): def ExitResult(self, result): """Make a list of tensors available in the outer context.""" if self._outer_context: - for x in result: - self._outer_context.AddName(x.name) + nest.map_structure(lambda x: self._outer_context.AddName(x.name), result) def GetWhileContext(self): """Return the while context containing this context.""" diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index a88143224f..21a5afabe0 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -678,6 +678,22 @@ class DataTypesTest(TensorFlowTestCase): [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6], [11, TestTuple(12, [13, 14]), np.ones([5, 5]), 16]) + def test_cond_inside_while_loop(self): + def Body(i, matrix): + result_tuple, unused_matrix = control_flow_ops.cond( + constant_op.constant(True), + lambda: (TestTuple(matrix * 2, matrix * 4), matrix), + lambda: (TestTuple(matrix * 4, matrix * 2), matrix)) + return [i+1, result_tuple.a] + + iteration, matrix = control_flow_ops.while_loop( + lambda i, matrix: i < 10, + Body, + loop_vars=[constant_op.constant(0), array_ops.ones([2, 2])]) + + self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([])) + self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2])) + if __name__ == "__main__": googletest.main() |