diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-12-08 10:46:46 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-08 10:53:39 -0800 |
commit | ec0f20465e4cac9b45e6bf840c29487911c76d3f (patch) | |
tree | 276f73115150b191805af7e098dd3aad861b4f49 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | 3a0dd455f0612a104ec81afb847615d21f4ccce0 (diff) |
Fix tf.while_loop with maximum_iterations != None and single loop_var.
PiperOrigin-RevId: 178396322
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 51eb13b921..3a61d76f58 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -753,6 +753,15 @@ class ControlFlowTest(test.TestCase): r = isum(s, maximum_iterations=3) self.assertAllEqual([1+3, 2+3, 3+3, 4+3, 5+3], r.eval()) + def testWhileWithMaximumIterationsAndSingleArgument(self): + with self.test_session(): + r = control_flow_ops.while_loop( + lambda i: i < 3, + lambda i: i + 1, + [0], + maximum_iterations=1) + self.assertEqual(1, r.eval()) + # Have more than 10 parallel iterations and hence exercise k-bound # most of the time. def testWhile_3(self): @@ -3014,6 +3023,16 @@ class EagerTest(test.TestCase): self.assertAllEqual(isum(tensor, maximum_iterations=3).numpy(), [1+3, 2+3, 3+3, 4+3, 5+3]) + def testWhileWithMaximumIterationsAndSingleArgument(self): + with context.eager_mode(): + tensor = constant_op.constant(0) + r = control_flow_ops.while_loop( + lambda i: i < 3, + lambda i: i + 1, + [tensor], + maximum_iterations=1) + self.assertEqual(1, r.numpy()) + def testWithDependencies(self): with context.eager_mode(): t1 = constant_op.constant(1) |