diff options
author | Alexandre Passos <apassos@google.com> | 2018-02-05 13:03:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-05 13:08:15 -0800 |
commit | f8f921c828fb2c97da7c7b80c01390ccec90ae40 (patch) | |
tree | f013f0fab629d01afbbf2eb42e7d9ac5715ead40 /tensorflow/python/kernel_tests/control_flow_ops_py_test.py | |
parent | d0904cbe01c88332acb4faa8bede21adb5fa1de7 (diff) |
Fixes issue where external control dependencies in while loops are dropped.
Fixes #15891
PiperOrigin-RevId: 184573795
Diffstat (limited to 'tensorflow/python/kernel_tests/control_flow_ops_py_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 4fafc36014..15ff0ec09b 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -704,6 +704,36 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) self.assertEqual(10000, r.eval()) + def testWhileExternalControlDependencies(self): + with self.test_session(): + v = variables.Variable(0.0) + v.initializer.run() + increment = v.assign_add(1.0) + + def body_fn(i): + with ops.control_dependencies([increment]): + return i + i + + result = control_flow_ops.while_loop(cond=lambda i: i < 1, + body=body_fn, loop_vars=[1]) + result.eval() + self.assertAllEqual(v.eval(), 1.0) + + def testWhileExternalControlDependenciesNoInput(self): + with self.test_session(): + v = variables.Variable(0.0) + v.initializer.run() + increment = v.assign_add(1.0) + + def body_fn(unused_i): + with ops.control_dependencies([increment]): + return constant_op.constant(5, name="five") + + result = control_flow_ops.while_loop(cond=lambda i: i < 5, + body=body_fn, loop_vars=[0]) + result.eval() + self.assertAllEqual(v.eval(), 1.0) + def testWhileWithRefs_1(self): with self.test_session() as sess: x = variables.Variable(0)._ref() # pylint: disable=protected-access |