diff options
author | 2016-10-31 06:49:11 -0800 | |
---|---|---|
committer | 2016-10-31 08:02:57 -0700 | |
commit | 5ad6738c5117ebc2b9384a379a38fa0fccd587a0 (patch) | |
tree | 7a1a3446a9e0ff17d9f57f73852945f2197cfd47 /tensorflow/python/training/queue_runner.py | |
parent | 57f42975a1c02ae35ce6d56bda0603ef24894230 (diff) |
Allow a QueueRunner to create_threads on multiple sessions.
Change: 137701036
Diffstat (limited to 'tensorflow/python/training/queue_runner.py')
-rw-r--r-- | tensorflow/python/training/queue_runner.py | 35 |
1 files changed, 19 insertions, 16 deletions
diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py index ff77437c02..fa8964f69f 100644 --- a/tensorflow/python/training/queue_runner.py +++ b/tensorflow/python/training/queue_runner.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import threading +import weakref from tensorflow.core.protobuf import queue_runner_pb2 from tensorflow.python.framework import errors @@ -90,7 +91,9 @@ class QueueRunner(object): queue_closed_exception_types=queue_closed_exception_types) # Protect the count of runs to wait for. self._lock = threading.Lock() - self._runs = 0 + # A map from a session object to the number of outstanding queue runner + # threads for that session. + self._runs_per_session = weakref.WeakKeyDictionary() # List of exceptions raised by the running threads. self._exceptions_raised = [] @@ -234,9 +237,9 @@ class QueueRunner(object): except self._queue_closed_exception_types: # pylint: disable=catching-non-exception # This exception indicates that a queue was closed. with self._lock: - self._runs -= 1 + self._runs_per_session[sess] -= 1 decremented = True - if self._runs == 0: + if self._runs_per_session[sess] == 0: try: sess.run(self._close_op) except Exception as e: @@ -256,7 +259,7 @@ class QueueRunner(object): # Make sure we account for all terminations: normal or errors. if not decremented: with self._lock: - self._runs -= 1 + self._runs_per_session[sess] -= 1 def _close_on_stop(self, sess, cancel_op, coord): """Close the queue when the Coordinator requests stop. @@ -276,19 +279,19 @@ class QueueRunner(object): # pylint: enable=broad-except def create_threads(self, sess, coord=None, daemon=False, start=False): - """Create threads to run the enqueue ops. + """Create threads to run the enqueue ops for the given session. This method requires a session in which the graph was launched. It creates a list of threads, optionally starting them. There is one thread for each op passed in `enqueue_ops`. - The `coord` argument is an optional coordinator, that the threads will use + The `coord` argument is an optional coordinator that the threads will use to terminate together and report exceptions. If a coordinator is given, this method starts an additional thread to close the queue when the coordinator requests a stop. - This method may be called again as long as all threads from a previous call - have stopped. + If previously created threads for the given session are still running, no + new threads will be created. Args: sess: A `Session`. @@ -300,16 +303,16 @@ class QueueRunner(object): Returns: A list of threads. - - Raises: - RuntimeError: If threads from a previous call to `create_threads()` are - still running. """ with self._lock: - if self._runs > 0: - # Already started: no new threads to return. - return [] - self._runs = len(self._enqueue_ops) + try: + if self._runs_per_session[sess] > 0: + # Already started: no new threads to return. + return [] + except KeyError: + # We haven't seen this session yet. + pass + self._runs_per_session[sess] = len(self._enqueue_ops) self._exceptions_raised = [] ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord)) |