diff options
author | Patrick Nguyen <drpng@google.com> | 2016-11-08 14:52:25 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-08 16:32:24 -0800 |
commit | e8a129f4f974210ffd6201e5e5f4f2fdd29e79da (patch) | |
tree | 64c0ae4729ee916df0a1635c9b99da35f92e5f82 /tensorflow/python/training/queue_runner_impl.py | |
parent | f905c2561e07677791b00cd6f17fb12e9d407da8 (diff) |
Seal queue_runner's interface.
Change: 138568627
Diffstat (limited to 'tensorflow/python/training/queue_runner_impl.py')
-rw-r--r-- | tensorflow/python/training/queue_runner_impl.py | 421 |
1 file changed, 421 insertions, 0 deletions
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Create threads to run multiple enqueue ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import weakref

from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging


class QueueRunner(object):
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example in the canonical 'Input
  Reader' setup one set of threads generates filenames in a queue; a second set
  of threads read records from the files, processes them, and enqueues tensors
  on a second queue; a third set of threads dequeues these input records to
  construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.
  """

  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue.  That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`.  Each thread will run its
    enqueue op in parallel with the other threads.  The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
        when some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If both `queue_runner_def` and `queue` are specified.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
    """
    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session.  Weak keys so a dead session does not keep
    # its entry (or itself) alive.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None, queue_closed_exception_types=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Tuple of exception types, which indicate
        the queue has been safely closed.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    self._close_op = close_op
    self._cancel_op = cancel_op
    if queue_closed_exception_types is not None:
      # Validate eagerly: must be a non-empty tuple of OpError subclasses.
      if (not isinstance(queue_closed_exception_types, tuple)
          or not queue_closed_exception_types
          or not all(issubclass(t, errors.OpError)
                     for t in queue_closed_exception_types)):
        raise TypeError(
            "queue_closed_exception_types, when provided, "
            "must be a non-empty list of tf.error types, but saw: %s"
            % queue_closed_exception_types)
    self._queue_closed_exception_types = queue_closed_exception_types
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    if self._close_op is None:
      self._close_op = self._queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    if self._cancel_op is None:
      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)
    else:
      self._queue_closed_exception_types = tuple(
          self._queue_closed_exception_types)

  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    Args:
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    # Resolve the stored op/queue names against the default graph, prefixing
    # import_scope if one was given.
    self._queue = g.as_graph_element(
        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    self._enqueue_ops = [g.as_graph_element(
        ops.prepend_name_scope(op, import_scope))
                         for op in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.close_op_name, import_scope))
    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.cancel_op_name, import_scope))
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types)
    # Legacy support for old QueueRunnerDefs created before this field
    # was added.
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)

  @property
  def queue(self):
    return self._queue

  @property
  def enqueue_ops(self):
    return self._enqueue_ops

  @property
  def close_op(self):
    return self._close_op

  @property
  def cancel_op(self):
    return self._cancel_op

  @property
  def queue_closed_exception_types(self):
    return self._queue_closed_exception_types

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects.  The list is empty if no exception
      was captured.  (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.
    """
    # Tracks whether this thread has already decremented its entry in
    # _runs_per_session, so the `finally` clause decrements exactly once.
    decremented = False
    try:
      while True:
        if coord and coord.should_stop():
          break
        try:
          sess.run(enqueue_op)
        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs_per_session[sess] -= 1
            decremented = True
            # Last runner out for this session closes the queue.
            if self._runs_per_session[sess] == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs_per_session[sess] -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    # Blocks until the coordinator signals stop, then cancels pending
    # enqueues so blocked runner threads can exit.
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops for the given session.

    This method requires a session in which the graph was launched.  It creates
    a list of threads, optionally starting them.  There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions.  If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    If previously created threads for the given session are still running, no
    new threads will be created.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean.  If `True` make the threads daemon threads.
      start: Boolean.  If `True` starts the threads.  If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.
    """
    with self._lock:
      try:
        if self._runs_per_session[sess] > 0:
          # Already started: no new threads to return.
          return []
      except KeyError:
        # We haven't seen this session yet.
        pass
      self._runs_per_session[sess] = len(self._enqueue_ops)
      # Discard exceptions captured during any previous run.
      self._exceptions_raised = []

    ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord))
                   for op in self._enqueue_ops]
    if coord:
      # Extra thread that cancels pending enqueues when coord requests stop.
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord)))
    for t in ret_threads:
      if coord:
        coord.register_thread(t)
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self, export_scope=None):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the `Variable` is not in
      the specified name scope.
    """
    if (export_scope is None or
        self.queue.name.startswith(export_scope)):
      queue_runner_def = queue_runner_pb2.QueueRunnerDef()
      queue_runner_def.queue_name = ops.strip_name_scope(
          self.queue.name, export_scope)
      for enqueue_op in self.enqueue_ops:
        queue_runner_def.enqueue_op_name.append(
            ops.strip_name_scope(enqueue_op.name, export_scope))
      queue_runner_def.close_op_name = ops.strip_name_scope(
          self.close_op.name, export_scope)
      queue_runner_def.cancel_op_name = ops.strip_name_scope(
          self.cancel_op.name, export_scope)
      # Exception classes are serialized as their error-code enum values.
      queue_runner_def.queue_closed_exception_types.extend([
          errors.error_code_from_exception_type(cls)
          for cls in self._queue_closed_exception_types])
      return queue_runner_def
    else:
      return None

  @staticmethod
  def from_proto(queue_runner_def, import_scope=None):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def,
                       import_scope=import_scope)


def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run.  This convenience function
  allows you to add a queue runner to a well known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)


def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`.  It just starts
  threads for all queue runners collected in the graph.  It returns
  the list of all threads.

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.
  """
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")
  # Make the session's graph the default graph while creating the threads.
  with sess.graph.as_default():
    threads = []
    for qr in ops.get_collection(collection):
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads


ops.register_proto_function(ops.GraphKeys.QUEUE_RUNNERS,
                            proto_type=queue_runner_pb2.QueueRunnerDef,
                            to_proto=QueueRunner.to_proto,
                            from_proto=QueueRunner.from_proto)