aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/control_flow_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops_test.py')
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py26
1 files changed, 24 insertions, 2 deletions
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 59bb925df0..153548ae92 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -939,7 +939,7 @@ class CaseTest(test_util.TensorFlowTestCase):
class WhileLoopTestCase(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testWhileLoopWithSingleVariable(self):
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, 10)
@@ -948,7 +948,7 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase):
self.assertEqual(self.evaluate(r), 10)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self):
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, 10)
@@ -958,6 +958,28 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase):
# Expect a tuple since that is what the body returns.
self.assertEqual(self.evaluate(r), (10,))
+ def testWhileLoopSameReturnShape_False(self):
+ i = constant_op.constant(0)
+ c = lambda i, _: math_ops.less(i, 10)
+
+ # Body returns a [tensor, []]
+ b = lambda i, _: [math_ops.add(i, 1), []]
+
+ # Should only return the tensor.
+ r = control_flow_ops.while_loop(c, b, [i, []])
+ self.assertEqual(self.evaluate(r), 10)
+
+ def testWhileLoopSameReturnShape_True(self):
+ i = constant_op.constant(0)
+ c = lambda i, _: math_ops.less(i, 10)
+
+ # Body returns a [tensor, []]
+ b = lambda i, _: [math_ops.add(i, 1), []]
+
+ # Should only return the original structure.
+ r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True)
+ self.assertEqual(self.evaluate(r), [10, []])
+
if __name__ == "__main__":
googletest.main()