about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
author: Eugene Brevdo <ebrevdo@google.com> 2018-03-21 08:25:34 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-03-21 08:27:54 -0700
commit: 56054e42a474a527f12f4d8d0b1f37eb1efd189d (patch)
tree: b32483c3a3b4d49903354c02f4731ac65fe644b2 /tensorflow/contrib/framework
parent: 2a9387d771f4ba99ba09b197ede82a6ea9671af0 (diff)
[tf.contrib CriticalSection] Avoid deadlocks using additional control dependencies on the lock op.
PiperOrigin-RevId: 189910726
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- tensorflow/contrib/framework/python/ops/critical_section_ops.py 203
-rw-r--r-- tensorflow/contrib/framework/python/ops/critical_section_test.py 143
2 files changed, 277 insertions, 69 deletions
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py
index cc19372acf..1893d7b466 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py
@@ -24,10 +24,8 @@ import collections
# from tensorflow.core.protobuf import critical_section_pb2
from tensorflow.python.eager import context
-from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
@@ -48,6 +46,26 @@ class _ExecutionSignature(
pass
+def _identity(x):
+ """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`."""
+ if isinstance(x, tensor_array_ops.TensorArray):
+ return x.identity()
+ elif isinstance(x, ops.Operation):
+ return control_flow_ops.group(x)
+ elif context.executing_eagerly() and x is None:
+ return None
+ else:
+ return array_ops.identity(x)
+
+
+def _get_colocation(op):
+ """Get colocation symbol from op, if any."""
+ try:
+ return op.get_attr("_class")
+ except ValueError:
+ return None
+
+
class CriticalSection(object):
"""Critical section.
@@ -180,8 +198,8 @@ class CriticalSection(object):
The tensors returned from `fn(*args, **kwargs)`.
Raises:
- ValueError: If `fn` attempts to use this `CriticalSection` in any nested
- way.
+ ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
+ or lazy way that may cause a deadlock.
ValueError: If `exclusive_resource_access` is not provided (is `True`) and
another `CriticalSection` has an execution requesting the same
resources as in `*args`, `**kwargs`, and any additionally captured
@@ -193,69 +211,52 @@ class CriticalSection(object):
exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)
with ops.name_scope(name, "critical_section_execute", []):
- lock = gen_resource_variable_ops.mutex_lock(self._handle)
-
- with ops.control_dependencies([lock]):
- c_known_ops = set()
- c_captured_tensors = set()
- def add_op_internal(op):
- c_known_ops.add(op)
- for i in op.inputs:
- if i.op not in c_known_ops:
- c_captured_tensors.add(i)
+ # Ensure that mutex locking only happens *after* all args and
+ # kwargs have been executed. This avoids certain types of deadlocks.
+ lock = gen_resource_variable_ops.mutex_lock(self._handle)
- c = function.HelperContext(add_op_internal)
- with c:
+ if not context.executing_eagerly():
+ # NOTE(ebrevdo): This is to ensure we don't pick up spurious
+ # Operations created by other threads.
+ with ops.get_default_graph()._lock: # pylint: disable=protected-access
+ existing_ops = ops.get_default_graph().get_operations()
+ with ops.control_dependencies([lock]):
+ r = fn(*args, **kwargs)
+ # TODO(ebrevdo): If creating critical sections in a python loop, this
+ # makes graph creation time quadratic. Revisit if this
+ # becomes a problem.
+ created_ops = (set(ops.get_default_graph().get_operations())
+ .difference(existing_ops))
+ else:
+ with ops.control_dependencies([lock]):
r = fn(*args, **kwargs)
- resource_inputs = set([
- x for x in
- list(nest.flatten(args)) + nest.flatten(kwargs.values()) +
- list(c_captured_tensors)
- if tensor_util.is_tensor(x) and x.dtype == dtypes.resource])
-
- if self._handle in resource_inputs:
- raise ValueError("The function fn attempts to access the "
- "CriticalSection in which it would be running. "
- "This is illegal and would cause deadlocks. "
- "CriticalSection: %s." % self._handle)
-
if not context.executing_eagerly():
- # Collections and op introspection does not work in eager
- # mode. This is generally ok; since eager mode (as of
- # writing) executes sequentially anyway.
- for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
- sg_handle_name = ops.convert_to_tensor(sg.handle).name
- self_handle_name = ops.convert_to_tensor(self._handle).name
- if sg_handle_name == self_handle_name:
- # Other executions in the same critical section are allowed.
- continue
- if not (exclusive_resource_access or sg.exclusive_resource_access):
- # Neither execution requested exclusive access.
- continue
- resource_intersection = resource_inputs.intersection(sg.resources)
- if resource_intersection:
- raise ValueError(
- "This execution would access resources: %s. Either this "
- "lock (CriticalSection: %s) or lock '%s' "
- "(CriticalSection: %s) requested exclusive resource access "
- "of this resource. Did you mean to call execute with keyword "
- "argument exclusive_resource_access=False?" %
- (list(resource_intersection), self._handle.name,
- sg.op.name, sg.handle.name))
-
- def identity(x): # pylint: disable=invalid-name
- if isinstance(x, tensor_array_ops.TensorArray):
- return x.identity()
- elif isinstance(x, ops.Operation):
- return control_flow_ops.group(x)
- elif context.executing_eagerly() and x is None:
- return None
- else:
- return array_ops.identity(x)
-
- r_flat = [identity(x) for x in nest.flatten(r)]
+ self._add_control_dependencies_to_lock(created_ops, lock.op)
+
+ # captured_resources is the set of resource tensors that are
+ # direct inputs of ops created during fn(); resources reached only
+ # through ancestors of those ops in the graph are not included.
+ captured_resources = set([
+ input_ for op in created_ops
+ for input_ in op.inputs
+ if input_.dtype == dtypes.resource
+ ])
+
+ # NOTE(ebrevdo): The only time self._is_self_handle() is True
+ # in this call is if one of the ops recently created within
+ # execute() itself attempts to access the
+ # CriticalSection. This will cause a deadlock.
+ if any(self._is_self_handle(x) for x in captured_resources):
+ raise ValueError("The function fn attempts to directly access the "
+ "CriticalSection in which it would be running. "
+ "This is illegal and would cause deadlocks.")
+
+ self._check_multiple_access_to_resources(
+ captured_resources, exclusive_resource_access)
+
+ r_flat = [_identity(x) for x in nest.flatten(r)]
with ops.control_dependencies(r_flat):
# The identity must run on the same machine as self._handle
@@ -268,23 +269,93 @@ class CriticalSection(object):
# Make sure that if any element of r is accessed, all of
# them are executed together.
- r = nest.pack_sequence_as(
- r, control_flow_ops.tuple(nest.flatten(r)))
+ r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))
with ops.control_dependencies([ensure_lock_exists]):
- outputs = nest.map_structure(identity, r)
+ outputs = nest.map_structure(_identity, r)
if not context.executing_eagerly():
signature = _ExecutionSignature(
op=lock.op,
handle=self._handle,
- resources=list(resource_inputs),
+ resources=list(captured_resources),
exclusive_resource_access=exclusive_resource_access)
ops.add_to_collections(
CRITICAL_SECTION_EXECUTIONS, signature)
return outputs
+ def _add_control_dependencies_to_lock(self, created_ops, lock_op):
+ """To avoid deadlocks, all args must be executed before lock_op."""
+ # Get all arguments (explicit and captured) of all ops created by fn().
+ all_args = set([input_.op for op in created_ops for input_ in op.inputs])
+ all_args.update(
+ input_op for op in created_ops for input_op in op.control_inputs)
+ # Unfortunately, we can't use sets throughout because TF seems to
+ # create new Operation objects for the same op sometimes; and we
+ # can't rely on id(op).
+
+ # pylint: disable=protected-access
+ all_args_dict = dict((op._id, op) for op in all_args)
+
+ # Remove ops created within fn, or that lock_op already has a
+ # control dependency on. Also remove a possible self-loop.
+ for op in created_ops:
+ all_args_dict.pop(op._id, None)
+ for op in lock_op.control_inputs:
+ all_args_dict.pop(op._id, None)
+ for input_ in lock_op.inputs:
+ all_args_dict.pop(input_.op._id, None)
+ all_args_dict.pop(lock_op._id, None)
+
+ lock_op._add_control_inputs(all_args_dict.values())
+ # pylint: enable=protected-access
+
+ def _is_self_handle(self, x):
+ """Check if the tensor `x` is the same Mutex as `self._handle`."""
+ return (x.op.type == "MutexV2"
+ # blank shared_name means the op will create a unique one.
+ and x.op.get_attr("shared_name")
+ and (x.op.get_attr("shared_name") ==
+ self._handle.op.get_attr("shared_name"))
+ and (x.op.device == self._handle.op.device
+ or _get_colocation(x.op) == _get_colocation(self._handle.op)))
+
+ def _check_multiple_access_to_resources(
+ self, captured_resources, exclusive_resource_access):
+ """Raise if captured_resources are accessed by another CriticalSection.
+
+ Args:
+ captured_resources: Set of tensors of type resource.
+ exclusive_resource_access: Whether this execution requires exclusive
+ resource access.
+
+ Raises:
+ ValueError: If any tensors in `captured_resources` are also accessed
+ by another `CriticalSection`, and at least one of them requires
+ exclusive resource access.
+ """
+ # Collections and op introspection do not work in eager
+ # mode. This is generally ok, since eager mode (as of this
+ # writing) executes sequentially anyway.
+ for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
+ if self._is_self_handle(sg.handle):
+ # Other executions in the same critical section are allowed.
+ continue
+ if not (exclusive_resource_access or sg.exclusive_resource_access):
+ # Neither execution requested exclusive access.
+ continue
+ resource_intersection = captured_resources.intersection(sg.resources)
+ if resource_intersection:
+ raise ValueError(
+ "This execution would access resources: %s. Either this "
+ "lock (CriticalSection: %s) or lock '%s' "
+ "(CriticalSection: %s) requested exclusive resource access "
+ "of this resource. Did you mean to call execute with keyword "
+ "argument exclusive_resource_access=False?" %
+ (list(resource_intersection), self._handle.name,
+ sg.op.name, sg.handle.name))
+
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# def to_proto(self, export_scope=None):
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py
index c916592ce1..e24140bd72 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_test.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# from tensorflow.python.training import saver as saver_lib
@@ -37,7 +38,7 @@ class CriticalSectionTest(test.TestCase):
v = resource_variable_ops.ResourceVariable(0.0, name="v")
def fn(a, b):
- c = v.read_value()
+ c = v.value()
with ops.control_dependencies([c]):
nv = v.assign_add(a * b)
with ops.control_dependencies([nv]):
@@ -143,12 +144,148 @@ class CriticalSectionTest(test.TestCase):
# This does not work properly in eager mode. Eager users will
# just hit a deadlock if they do this. But at least it'll be easier
# to debug.
+ cs = critical_section_ops.CriticalSection()
+ def fn(x):
+ return cs.execute(lambda y: y + 1, x)
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"attempts to directly access the CriticalSection in which it "
+ r"would be running"):
+ cs.execute(fn, 1.0)
+
+ def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
+ # This one is subtle; and we're being overly cautious here. The
+ # deadlock we are ensuring we catch is:
+ #
+ # to_capture = CS[lambda x: x + 1](1.0)
+ # deadlocked = CS[lambda x: x + to_capture](1.0)
+ #
+ # This would have caused a deadlock because executing `deadlocked` will
+ # lock the mutex on CS; but then due to dependencies, will attempt
+ # to compute `to_capture`. This computation requires locking CS,
+ # but that is not possible now because CS is already locked by
+ # `deadlocked`.
+ #
+ # We check that CriticalSection.execute properly inserts new
+ # control dependencies to its lock to ensure all captured
+ # operations are finished before anything runs within the critical section.
+ cs = critical_section_ops.CriticalSection(shared_name="cs")
+ fn = array_ops.identity
+ to_capture = cs.execute(fn, 1.0)
+ fn_captures = lambda x: x + to_capture
+ to_capture_too = array_ops.identity(to_capture)
+
+ ex_0 = cs.execute(fn_captures, 1.0)
+
+ with ops.control_dependencies([to_capture]):
+ # This is OK because to_capture will execute before this next call
+ ex_1 = cs.execute(fn_captures, 1.0)
+
+ dependency = array_ops.identity(to_capture)
+
+ fn_captures_dependency = lambda x: x + dependency
+
+ ex_2 = cs.execute(fn_captures_dependency, 1.0)
+
+ with ops.control_dependencies([to_capture_too]):
+ ex_3 = cs.execute(fn_captures_dependency, 1.0)
+
+ # Ensure there's no actual deadlock on to_capture.
+ self.assertEquals(2.0, self.evaluate(ex_0))
+ self.assertEquals(2.0, self.evaluate(ex_1))
+ self.assertEquals(2.0, self.evaluate(ex_2))
+ self.assertEquals(2.0, self.evaluate(ex_3))
+
+ def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
+ cs = critical_section_ops.CriticalSection(shared_name="cs")
+
+ def body_implicit_capture(i, j):
+ # This would have caused a deadlock if not for logic in execute
+ # that inserts additional control dependencies onto the lock op:
+ # * Loop body argument j is captured by fn()
+ # * i is running in parallel to move forward the execution
+ # * j is not being checked by the predicate function
+ # * output of cs.execute() is returned as next j.
+ fn = lambda: j + 1
+ return (i + 1, cs.execute(fn))
+
+ (i_n, j_n) = control_flow_ops.while_loop(
+ lambda i, _: i < 1000,
+ body_implicit_capture,
+ [0, 0],
+ parallel_iterations=25)
+ logging.warn(
+ "\n==============\nRunning "
+ "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+ "body_implicit_capture'\n"
+ "==============\n")
+ self.assertEquals((1000, 1000), self.evaluate((i_n, j_n)))
+ logging.warn(
+ "\n==============\nSuccessfully finished running "
+ "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+ "body_implicit_capture'\n"
+ "==============\n")
+
+ def body_implicit_capture_protected(i, j):
+ # This version is ok because we manually add a control
+ # dependency on j, which is an argument to the while_loop body
+ # and captured by fn.
+ fn = lambda: j + 1
+ with ops.control_dependencies([j]):
+ return (i + 1, cs.execute(fn))
+
+ (i_n, j_n) = control_flow_ops.while_loop(
+ lambda i, _: i < 1000,
+ body_implicit_capture_protected,
+ [0, 0],
+ parallel_iterations=25)
+ logging.warn(
+ "\n==============\nRunning "
+ "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+ "body_implicit_capture_protected'\n"
+ "==============\n")
+ self.assertEquals((1000, 1000), self.evaluate((i_n, j_n)))
+ logging.warn(
+ "\n==============\nSuccessfully finished running "
+ "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+ "body_implicit_capture_protected'\n"
+ "==============\n")
+
+ def body_args_capture(i, j):
+ # This version is ok because j is an argument to fn and we can
+ # ensure there's a control dependency on j.
+ fn = lambda x: x + 1
+ return (i + 1, cs.execute(fn, j))
+
+ (i_n, j_n) = control_flow_ops.while_loop(
+ lambda i, _: i < 1000,
+ body_args_capture,
+ [0, 0],
+ parallel_iterations=25)
+ logging.warn(
+ "\n==============\nRunning "
+ "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+ "body_args_capture'\n"
+ "==============\n")
+ self.assertEquals((1000, 1000), self.evaluate((i_n, j_n)))
+ logging.warn(
+ "\n==============\nSuccessfully finished running "
+ "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
+ "body_args_capture'\n"
+ "==============\n")
+
+ def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
+ # This does not work properly in eager mode. Eager users will
+ # just hit a deadlock if they do this. But at least it'll be easier
+ # to debug.
cs = critical_section_ops.CriticalSection(shared_name="cs")
+ cs_same = critical_section_ops.CriticalSection(shared_name="cs")
def fn(x):
- return cs.execute(lambda x: x+1, x)
+ return cs_same.execute(lambda x: x+1, x)
with self.assertRaisesRegexp(
ValueError,
- r"attempts to access the CriticalSection in which it would be running"):
+ r"attempts to directly access the CriticalSection in which it "
+ r"would be running"):
cs.execute(fn, 1.0)
def testMultipleCSExecutionsRequestSameResource(self):