diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-03-21 13:28:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-21 13:31:07 -0700 |
commit | 56e5181f340f855e0eef9a4ce25baea5be1aaebc (patch) | |
tree | c7f506aa50eccf805e6a95264c1d8ab448f55419 /tensorflow/contrib/framework | |
parent | 6741f81b8216862a83703122191a8632fda333a2 (diff) |
[TF CriticalSection] Bugfix when Execute() inside a while_loop has a dep on a Variable outside of it.
PiperOrigin-RevId: 189957569
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- | tensorflow/contrib/framework/python/ops/critical_section_ops.py | 14 | ||||
-rw-r--r-- | tensorflow/contrib/framework/python/ops/critical_section_test.py | 14 |
2 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index 1893d7b466..bd764ed57a 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -308,7 +308,19 @@ class CriticalSection(object): all_args_dict.pop(input_.op._id, None) all_args_dict.pop(lock_op._id, None) - lock_op._add_control_inputs(all_args_dict.values()) + all_args = all_args_dict.values() + + if not all_args: + # No control dependencies to add; return early. + return + + # This group is important: it ensures that any ops in all_args + # outside the control context of the lock_op (and this fn, which + # runs in the same context) are added to this context before + # being added to the control dependencies of lock_op. + all_args = control_flow_ops.group(*all_args) + + lock_op._add_control_input(all_args) # pylint: enable=protected-access def _is_self_handle(self, x): diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index e24140bd72..ba660295cb 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -316,6 +316,20 @@ class CriticalSectionTest(test.TestCase): ValueError, "requested exclusive resource access"): cs1.execute(lambda: v2 + 1) + def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self): + cs = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(0, name="v") + # Make sure that the control dependencies on v do not cause issues + # in the lock_op's automatic control dependency adder. + # + # Note, here v must be a resource variable (or something similar), + # otherwise it gets hoisted into the while_loop by the time we add + # control dependencies to the lock_op. + out = control_flow_ops.while_loop( + lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0]) + self.evaluate(v.initializer) + self.assertEqual(10, self.evaluate(out)) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): |