# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critical Section object and execution logic."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# from tensorflow.core.protobuf import critical_section_pb2

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest


# Graph Keys
CRITICAL_SECTIONS = "critical_sections"
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"


class _ExecutionSignature(
    collections.namedtuple("_ExecutionSignature",
                           ("op", "handle",
                            "resources", "exclusive_resource_access"))):
  """A class storing an `ExecuteInCriticalResource` op and associated attrs."""
  pass


def _identity(x):
  """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`."""
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  elif isinstance(x, ops.Operation):
    return control_flow_ops.group(x)
  elif context.executing_eagerly() and x is None:
    return None
  else:
    return array_ops.identity(x)


def _get_colocation(op):
  """Get colocation symbol from op, if any."""
  try:
    return op.get_attr("_class")
  except ValueError:
    return None


class CriticalSection(object):
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order.  A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```
  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  If multiple critical sections are used to access the same resources, there is
  no guarantee of exclusive access to those resources.  For this reason, such
  usage is disallowed by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.get_variable("v", initializer=0.0, use_resource=True)
  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)
  ex1 = CriticalSection().execute(
    accumulate, 1.0, exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
    accumulate, 1.0, exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section."""
    if critical_section_def and name is not None:
      raise ValueError("critical_section_def and shared_name are "
                       "mutually exclusive.")
    if critical_section_def:
      self._init_from_proto(critical_section_def, import_scope=import_scope)
    else:
      self._init_from_args(name, shared_name)

  def _init_from_proto(self, critical_section_def, import_scope):  # pylint: disable=invalid-name
    raise NotImplementedError("Not yet implemented")
    # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    # assert isinstance(
    #     critical_section_def, critical_section_pb2.CriticalSectionDef)
    # # Create from critical_section_def.
    # g = ops.get_default_graph()
    # self._handle = g.as_graph_element(
    #     ops.prepend_name_scope(
    #         critical_section_def.critical_section_name,
    #         import_scope=import_scope))

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
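        # init_scope ensures the mutex resource below is created in the
        # outermost (graph or eager) context, outside any function-building
        # graph that may currently be active.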
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    return self._handle.op.name

  def execute(self, fn, *args, **kwargs):
    """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute.  Must return at least one tensor.
      *args: Additional positional arguments to `fn`.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`.  These are:

        - name: The name to use when creating the execute operation.
        - exclusive_resource_access: Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access` is left at its default of
        `True` and another `CriticalSection` has an execution requesting the
        same resources as `*args`, `**kwargs`, and any additionally captured
        inputs in `fn`.  Note that the error is raised unless *both* this
        execution and the conflicting one were created with
        `exclusive_resource_access=False`.
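
    A minimal usage sketch (mirroring the class docstring; the helper
    `add_one` below is illustrative only):

    ```python
    v = resource_variable_ops.ResourceVariable(0.0, name="v")
    cs = CriticalSection()

    def add_one():
      # Snapshot the current value, then increment it, all under the lock.
      value = v.read_value()
      with tf.control_dependencies([value]):
        with tf.control_dependencies([v.assign_add(1.0)]):
          return tf.identity(value)

    first = cs.execute(add_one)
    second = cs.execute(add_one, name="second_execution")
    # `first` and `second` are computed serially under the same mutex.
    ```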
    """
    name = kwargs.pop("name", None)
    exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

    with ops.name_scope(name, "critical_section_execute", []):

      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed.  This avoids certain types of deadlocks.
      lock = gen_resource_variable_ops.mutex_lock(self._handle)

      if not context.executing_eagerly():
        # NOTE(ebrevdo): This is to ensure we don't pick up spurious
        # Operations created by other threads.
        with ops.get_default_graph()._lock:  # pylint: disable=protected-access
          existing_ops = ops.get_default_graph().get_operations()
          with ops.control_dependencies([lock]):
            r = fn(*args, **kwargs)
          # TODO(ebrevdo): If creating critical sections in a python loop, this
          # makes graph creation time quadratic.  Revisit if this
          # becomes a problem.
          created_ops = (set(ops.get_default_graph().get_operations())
                         .difference(existing_ops))
      else:
        with ops.control_dependencies([lock]):
          r = fn(*args, **kwargs)

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a set of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = set([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() returns True
        # here is if one of the ops created within execute() itself
        # attempts to access the CriticalSection.  This would cause a
        # deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError("The function fn attempts to directly access the "
                           "CriticalSection in which it would be running.  "
                           "This is illegal and would cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

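      # Flatten the structure returned by fn and take identities so that
      # the lock-release op below can depend on every output of fn.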
      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled (!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op."""
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = set([input_.op for op in created_ops for input_ in op.inputs])
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)
    # Unfortunately, we can't use sets throughout because TF sometimes
    # creates new Operation objects for the same op, so we can't rely
    # on id(op).

    # pylint: disable=protected-access
    all_args_dict = dict((op._id, op) for op in all_args)

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`."""
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection do not work in eager
    # mode.  This is generally OK, since eager mode (as of this
    # writing) executes operations sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: %s.  Either this "
            "lock (CriticalSection: %s) or lock '%s' "
            "(CriticalSection: %s) requested exclusive resource access "
            "of this resource.  Did you mean to call execute with keyword "
            "argument exclusive_resource_access=False?" %
            (list(resource_intersection), self._handle.name,
             sg.op.name, sg.handle.name))

  # TODO(ebrevdo): Re-enable once CriticalSection is in core.

  # def to_proto(self, export_scope=None):
  #   """Converts a `CriticalSection` to a `CriticalSectoinDef` protocol buffer.

  #   Args:
  #     export_scope: Optional `string`. Name scope to remove.

  #   Returns:
  #     A `CriticalSectionDef` protocol buffer, or `None` if the
  #     `CriticalSection` is not in the specified name scope.
  #   """
  #   if export_scope is None or self.handle.name.startswith(export_scope):
  #     cs_def = critical_section_pb2.CriticalSectionDef()
  #     cs_def.critical_section_name = ops.strip_name_scope(
  #         self._handle.name, export_scope)
  #     return cs_def
  #   else:
  #     return None

  # @staticmethod
  # def from_proto(critical_section_def, import_scope=None):
  #   return CriticalSection(
  #       critical_section_def=critical_section_def, import_scope=import_scope)


# TODO(ebrevdo): Re-enable once CriticalSection is in core.

# def _execution_to_proto_fn(execution_signature, export_scope=None):
#   """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
#   # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.

#   Args:
#     execution_signature: Instance of `_ExecutionSignature`.
#     export_scope: The export scope, if any.

#   Returns:
#     An instance of `CriticalSectionExecutionDef`.
#   """
#   if (export_scope is None
#       or execution_signature.op.name.startswith(export_scope)):
#     op_def = critical_section_pb2.CriticalSectionExecutionDef()
#     op_def.execute_in_critical_section_name = ops.strip_name_scope(
#         execution_signature.op.name, export_scope)
#     op_def.exclusive_resource_access = (
#         execution_signature.exclusive_resource_access)
#     return op_def
#   else:
#     return None


# def _execution_from_proto_fn(op_def, import_scope=None):
#   """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
#   # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
#   assert isinstance(
#       op_def, critical_section_pb2.CriticalSectionExecutionDef)

#   # Create from op_def.
#   g = ops.get_default_graph()
#   execution_op = g.as_graph_element(
#       ops.prepend_name_scope(
#           op_def.execute_in_critical_section_name,
#           import_scope=import_scope))
#   return _ExecutionSignature(
#       op=execution_op,
#       exclusive_resource_access=op_def.exclusive_resource_access)

# ops.register_proto_function(
#     CRITICAL_SECTIONS,
#     proto_type=critical_section_pb2.CriticalSectionDef,
#     to_proto=CriticalSection.to_proto,
#     from_proto=CriticalSection.from_proto)

# ops.register_proto_function(
#     CRITICAL_SECTION_EXECUTIONS,
#     proto_type=critical_section_pb2.CriticalSectionExecutionDef,
#     to_proto=_execution_to_proto_fn,
#     from_proto=_execution_from_proto_fn)