aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-03-21 13:28:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-21 13:31:07 -0700
commit56e5181f340f855e0eef9a4ce25baea5be1aaebc (patch)
treec7f506aa50eccf805e6a95264c1d8ab448f55419 /tensorflow/contrib/framework
parent6741f81b8216862a83703122191a8632fda333a2 (diff)
[TF CriticalSection] Bugfix when Execute() inside a while_loop has a dep on a Variable outside of it.
PiperOrigin-RevId: 189957569
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/python/ops/critical_section_ops.py14
-rw-r--r--tensorflow/contrib/framework/python/ops/critical_section_test.py14
2 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py
index 1893d7b466..bd764ed57a 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py
@@ -308,7 +308,19 @@ class CriticalSection(object):
all_args_dict.pop(input_.op._id, None)
all_args_dict.pop(lock_op._id, None)
- lock_op._add_control_inputs(all_args_dict.values())
+ all_args = all_args_dict.values()
+
+ if not all_args:
+ # No control dependencies to add; return early.
+ return
+
+ # This group is important: it ensures that any ops in all_args
+ # outside the control context of the lock_op (and this fn, which
+ # runs in the same context) are added to this context before
+ # being added to the control dependencies of lock_op.
+ all_args = control_flow_ops.group(*all_args)
+
+ lock_op._add_control_input(all_args)
# pylint: enable=protected-access
def _is_self_handle(self, x):
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py
index e24140bd72..ba660295cb 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_test.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py
@@ -316,6 +316,20 @@ class CriticalSectionTest(test.TestCase):
ValueError, "requested exclusive resource access"):
cs1.execute(lambda: v2 + 1)
+ def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
+ cs = critical_section_ops.CriticalSection()
+ v = resource_variable_ops.ResourceVariable(0, name="v")
+ # Make sure that the control dependencies on v do not cause issues
+ # in the lock_op's automatic control dependency adder.
+ #
+ # Note, here v must be a resource variable (or something similar),
+ # otherwise it gets hoisted into the while_loop by the time we add
+ # control dependencies to the lock_op.
+ out = control_flow_ops.while_loop(
+ lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0])
+ self.evaluate(v.initializer)
+ self.assertEqual(10, self.evaluate(out))
+
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
#
# def testCriticalSectionAndExecuteOpSaverRoundTrip(self):