aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-02-06 12:04:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-06 12:45:32 -0800
commit87b29b7f3768328f02170aeb4338dd232be00248 (patch)
tree4460cca23b69d647125466abc1b77021dd522637 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py
parentd9df4313a98fdc62187a94c5ab6d8955b699e9f2 (diff)
Second, cleaner, attempt at external control dependency handling.
PiperOrigin-RevId: 184718016
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 4fafc36014..15ff0ec09b 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -704,6 +704,36 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ def testWhileExternalControlDependencies(self):
+ with self.test_session():
+ v = variables.Variable(0.0)
+ v.initializer.run()
+ increment = v.assign_add(1.0)
+
+ def body_fn(i):
+ with ops.control_dependencies([increment]):
+ return i + i
+
+ result = control_flow_ops.while_loop(cond=lambda i: i < 1,
+ body=body_fn, loop_vars=[1])
+ result.eval()
+ self.assertAllEqual(v.eval(), 1.0)
+
+ def testWhileExternalControlDependenciesNoInput(self):
+ with self.test_session():
+ v = variables.Variable(0.0)
+ v.initializer.run()
+ increment = v.assign_add(1.0)
+
+ def body_fn(unused_i):
+ with ops.control_dependencies([increment]):
+ return constant_op.constant(5, name="five")
+
+ result = control_flow_ops.while_loop(cond=lambda i: i < 5,
+ body=body_fn, loop_vars=[0])
+ result.eval()
+ self.assertAllEqual(v.eval(), 1.0)
+
def testWhileWithRefs_1(self):
with self.test_session() as sess:
x = variables.Variable(0)._ref() # pylint: disable=protected-access