aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-12-08 10:46:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-08 10:53:39 -0800
commitec0f20465e4cac9b45e6bf840c29487911c76d3f (patch)
tree276f73115150b191805af7e098dd3aad861b4f49 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parent3a0dd455f0612a104ec81afb847615d21f4ccce0 (diff)
Fix tf.while_loop with maximum_iterations != None and single loop_var.
PiperOrigin-RevId: 178396322
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 51eb13b921..3a61d76f58 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -753,6 +753,15 @@ class ControlFlowTest(test.TestCase):
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1+3, 2+3, 3+3, 4+3, 5+3], r.eval())
+ def testWhileWithMaximumIterationsAndSingleArgument(self):
+ with self.test_session():
+ r = control_flow_ops.while_loop(
+ lambda i: i < 3,
+ lambda i: i + 1,
+ [0],
+ maximum_iterations=1)
+ self.assertEqual(1, r.eval())
+
# Have more than 10 parallel iterations and hence exercise k-bound
# most of the time.
def testWhile_3(self):
@@ -3014,6 +3023,16 @@ class EagerTest(test.TestCase):
self.assertAllEqual(isum(tensor, maximum_iterations=3).numpy(),
[1+3, 2+3, 3+3, 4+3, 5+3])
+ def testWhileWithMaximumIterationsAndSingleArgument(self):
+ with context.eager_mode():
+ tensor = constant_op.constant(0)
+ r = control_flow_ops.while_loop(
+ lambda i: i < 3,
+ lambda i: i + 1,
+ [tensor],
+ maximum_iterations=1)
+ self.assertEqual(1, r.numpy())
+
def testWithDependencies(self):
with context.eager_mode():
t1 = constant_op.constant(1)