aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-02-05 13:03:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-05 13:08:15 -0800
commitf8f921c828fb2c97da7c7b80c01390ccec90ae40 (patch)
treef013f0fab629d01afbbf2eb42e7d9ac5715ead40 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parentd0904cbe01c88332acb4faa8bede21adb5fa1de7 (diff)
Fixes issue where external control dependencies in while loops are dropped.
Fixes #15891 PiperOrigin-RevId: 184573795
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 4fafc36014..15ff0ec09b 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -704,6 +704,36 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ def testWhileExternalControlDependencies(self):
+ with self.test_session():
+ v = variables.Variable(0.0)
+ v.initializer.run()
+ increment = v.assign_add(1.0)
+
+ def body_fn(i):
+ with ops.control_dependencies([increment]):
+ return i + i
+
+ result = control_flow_ops.while_loop(cond=lambda i: i < 1,
+ body=body_fn, loop_vars=[1])
+ result.eval()
+ self.assertAllEqual(v.eval(), 1.0)
+
+ def testWhileExternalControlDependenciesNoInput(self):
+ with self.test_session():
+ v = variables.Variable(0.0)
+ v.initializer.run()
+ increment = v.assign_add(1.0)
+
+ def body_fn(unused_i):
+ with ops.control_dependencies([increment]):
+ return constant_op.constant(5, name="five")
+
+ result = control_flow_ops.while_loop(cond=lambda i: i < 5,
+ body=body_fn, loop_vars=[0])
+ result.eval()
+ self.assertAllEqual(v.eval(), 1.0)
+
def testWhileWithRefs_1(self):
with self.test_session() as sess:
x = variables.Variable(0)._ref() # pylint: disable=protected-access