aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/control_flow_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops_test.py')
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 43fe045bcb..153548ae92 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -958,6 +958,28 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase):
# Expect a tuple since that is what the body returns.
self.assertEqual(self.evaluate(r), (10,))
+ def testWhileLoopSameReturnShape_False(self):
+ i = constant_op.constant(0)
+ c = lambda i, _: math_ops.less(i, 10)
+
+ # Body returns a [tensor, []]
+ b = lambda i, _: [math_ops.add(i, 1), []]
+
+ # Should only return the tensor.
+ r = control_flow_ops.while_loop(c, b, [i, []])
+ self.assertEqual(self.evaluate(r), 10)
+
+ def testWhileLoopSameReturnShape_True(self):
+ i = constant_op.constant(0)
+ c = lambda i, _: math_ops.less(i, 10)
+
+ # Body returns a [tensor, []]
+ b = lambda i, _: [math_ops.add(i, 1), []]
+
+ # Should only return the original structure.
+ r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True)
+ self.assertEqual(self.evaluate(r), [10, []])
+
if __name__ == "__main__":
googletest.main()