aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/queue_runner_impl.py
blob: ac9d4c850d0c143a70ddc645d0a7a332930cc6b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Create threads to run multiple enqueue ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import weakref

from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

_DEPRECATION_INSTRUCTION = (
    "To construct input pipelines, use the `tf.data` module.")


@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
class QueueRunner(object):
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example in the canonical 'Input
  Reader' setup one set of threads generates filenames in a queue; a second set
  of threads reads records from the files, processes them, and enqueues tensors
  on a second queue; a third set of threads dequeues these input records to
  construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.

  @compatibility(eager)
  QueueRunners are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue.  That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`.  Each thread will run its
    enqueue op in parallel with the other threads.  The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
        when some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If `queue_runner_def` is specified together with `queue` or
        `enqueue_ops`.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided but is not a
        non-empty tuple of `tf.errors.OpError` subclasses.
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "QueueRunners are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session. Weak keys so a dead session does not linger.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None, queue_closed_exception_types=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
        Created from `queue` when `None`.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
        Created from `queue` when `None`.
      queue_closed_exception_types: Tuple of exception types, which indicate
        the queue has been safely closed. Defaults to
        `(errors.OutOfRangeError,)` when `None`.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    # Validate before creating any close ops so a bad argument does not leave
    # stray ops in the graph.
    if queue_closed_exception_types is not None:
      if (not isinstance(queue_closed_exception_types, tuple)
          or not queue_closed_exception_types
          # `isinstance(t, type)` first: `issubclass` on a non-class raises its
          # own, unrelated TypeError message.
          or not all(isinstance(t, type) and issubclass(t, errors.OpError)
                     for t in queue_closed_exception_types)):
        # The value is wrapped in a 1-tuple: `%` would otherwise unpack a
        # tuple argument as several format args and replace this message with
        # "not all arguments converted during string formatting".
        raise TypeError(
            "queue_closed_exception_types, when provided, "
            "must be a tuple of tf.error types, but saw: %s"
            % (queue_closed_exception_types,))
      self._queue_closed_exception_types = queue_closed_exception_types
    else:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    self._close_op = close_op if close_op is not None else queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    self._cancel_op = (cancel_op if cancel_op is not None
                       else queue.close(cancel_pending_enqueues=True))

  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    Every name stored in the proto is resolved against the default graph after
    prepending `import_scope`.

    Args:
      queue_runner_def: A `QueueRunnerDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    self._queue = g.as_graph_element(
        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    self._enqueue_ops = [g.as_graph_element(
        ops.prepend_name_scope(op, import_scope))
                         for op in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.close_op_name, import_scope))
    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.cancel_op_name, import_scope))
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types)
    # Legacy support for old QueueRunnerDefs created before this field
    # was added.
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)

  @property
  def queue(self):
    """The queue this runner feeds."""
    return self._queue

  @property
  def enqueue_ops(self):
    """The enqueue ops; `create_threads()` makes one thread per op."""
    return self._enqueue_ops

  @property
  def close_op(self):
    """Op to close the queue. Pending enqueue ops are preserved."""
    return self._close_op

  @property
  def cancel_op(self):
    """Op to close the queue and cancel pending enqueue ops."""
    return self._cancel_op

  @property
  def queue_closed_exception_types(self):
    """Tuple of exception types the enqueue threads treat as 'queue closed'."""
    return self._queue_closed_exception_types

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects.  The list is empty if no exception
      was captured.  (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    The thread stops when `coord` requests a stop, when the enqueue op raises
    one of `queue_closed_exception_types`, or on any other exception. In every
    case the per-session count of running threads is decremented exactly once.
    The thread that brings the count to zero via a queue-closed exception runs
    `close_op`.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.

    Raises:
      Exception: Any non-queue-closed exception is re-raised when no `coord`
        is given (after being recorded in `exceptions_raised`).
    """
    decremented = False
    try:
      # Make a cached callable from the `enqueue_op` to decrease the
      # Python overhead in the queue-runner loop.
      enqueue_callable = sess.make_callable(enqueue_op)
      while True:
        if coord and coord.should_stop():
          break
        try:
          enqueue_callable()
        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs_per_session[sess] -= 1
            decremented = True
            if self._runs_per_session[sess] == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs_per_session[sess] -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Blocks until `coord` requests a stop, then runs `cancel_op`. Errors from
    `cancel_op` are logged at verbosity 1 and otherwise ignored.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops for the given session.

    This method requires a session in which the graph was launched.  It creates
    a list of threads, optionally starting them.  There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions.  If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    If previously created threads for the given session are still running, no
    new threads will be created. Note that the threads are counted as running
    for `sess` as soon as they are created, even when `start=False`.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean.  If `True` make the threads daemon threads.
      start: Boolean.  If `True` starts the threads.  If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.
    """
    with self._lock:
      if self._runs_per_session.get(sess, 0) > 0:
        # Already started: no new threads to return.
        return []
      self._runs_per_session[sess] = len(self._enqueue_ops)
      self._exceptions_raised = []

    ret_threads = []
    for op in self._enqueue_ops:
      name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
      ret_threads.append(threading.Thread(target=self._run,
                                          args=(sess, op, coord),
                                          name=name))
    if coord:
      name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord),
                                          name=name))
    for t in ret_threads:
      if coord:
        coord.register_thread(t)
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self, export_scope=None):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the queue's name does
      not start with `export_scope`.
    """
    if (export_scope is None or
        self.queue.name.startswith(export_scope)):
      queue_runner_def = queue_runner_pb2.QueueRunnerDef()
      queue_runner_def.queue_name = ops.strip_name_scope(
          self.queue.name, export_scope)
      for enqueue_op in self.enqueue_ops:
        queue_runner_def.enqueue_op_name.append(
            ops.strip_name_scope(enqueue_op.name, export_scope))
      queue_runner_def.close_op_name = ops.strip_name_scope(
          self.close_op.name, export_scope)
      queue_runner_def.cancel_op_name = ops.strip_name_scope(
          self.cancel_op.name, export_scope)
      queue_runner_def.queue_closed_exception_types.extend([
          errors.error_code_from_exception_type(cls)
          for cls in self._queue_closed_exception_types])
      return queue_runner_def
    else:
      return None

  @staticmethod
  def from_proto(queue_runner_def, import_scope=None):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def,
                       import_scope=import_scope)


@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Registers a `QueueRunner` in a collection of the default graph.

  Models built from many queues make it awkward to keep track of every queue
  runner that has to be started. Registering each runner in a well-known graph
  collection lets `start_queue_runners()` later find them all and start
  threads for them in one call.

  Args:
    qr: The `QueueRunner` to register.
    collection: `GraphKey` naming the graph collection that receives `qr`.
      Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  # Delegates to the graph-collection helper; nothing is returned.
  ops.add_to_collection(collection, qr)


@tf_export(v1=["train.queue_runner.start_queue_runners",
               "train.start_queue_runners"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`.  It just starts
  threads for all queue runners collected in the session's graph.  It returns
  the list of all threads.

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads. Empty when `sess` is a `MonitoredSession` or
    `SingularMonitoredSession` (kept for backward compatibility).

  Raises:
    RuntimeError: If called with eager execution enabled.
    ValueError: If `sess` is None and there isn't any default session.
    TypeError: If `sess` is not a `tf.Session` object.

  @compatibility(eager)
  Not compatible with eager execution. To ingest data under eager execution,
  use the `tf.data` API instead.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError("Queues are not compatible with eager execution.")
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  with sess.graph.as_default():
    # Fetch the collection once, from the session's graph. Those are the
    # runners actually started below, so the "no queue runners" warning must
    # be based on the same graph rather than whatever graph happens to be the
    # default at call time.
    queue_runners = ops.get_collection(collection)
    if not queue_runners:
      logging.warning(
          "`tf.train.start_queue_runners()` was called when no queue runners "
          "were defined. You can safely remove the call to this deprecated "
          "function.")
    threads = []
    for qr in queue_runners:
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads


# Register the QueueRunnerDef conversion functions for the QUEUE_RUNNERS graph
# collection, so entries of that collection can be converted to and rebuilt
# from `QueueRunnerDef` protos.
ops.register_proto_function(
    ops.GraphKeys.QUEUE_RUNNERS,
    proto_type=queue_runner_pb2.QueueRunnerDef,
    to_proto=QueueRunner.to_proto,
    from_proto=QueueRunner.from_proto)