aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/queue_runner.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-31 06:49:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-31 08:02:57 -0700
commit5ad6738c5117ebc2b9384a379a38fa0fccd587a0 (patch)
tree7a1a3446a9e0ff17d9f57f73852945f2197cfd47 /tensorflow/python/training/queue_runner.py
parent57f42975a1c02ae35ce6d56bda0603ef24894230 (diff)
Allow a QueueRunner to create_threads on multiple sessions.
Change: 137701036
Diffstat (limited to 'tensorflow/python/training/queue_runner.py')
-rw-r--r--tensorflow/python/training/queue_runner.py35
1 files changed, 19 insertions, 16 deletions
diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py
index ff77437c02..fa8964f69f 100644
--- a/tensorflow/python/training/queue_runner.py
+++ b/tensorflow/python/training/queue_runner.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import threading
+import weakref
from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.framework import errors
@@ -90,7 +91,9 @@ class QueueRunner(object):
queue_closed_exception_types=queue_closed_exception_types)
# Protect the count of runs to wait for.
self._lock = threading.Lock()
- self._runs = 0
+ # A map from a session object to the number of outstanding queue runner
+ # threads for that session.
+ self._runs_per_session = weakref.WeakKeyDictionary()
# List of exceptions raised by the running threads.
self._exceptions_raised = []
@@ -234,9 +237,9 @@ class QueueRunner(object):
except self._queue_closed_exception_types: # pylint: disable=catching-non-exception
# This exception indicates that a queue was closed.
with self._lock:
- self._runs -= 1
+ self._runs_per_session[sess] -= 1
decremented = True
- if self._runs == 0:
+ if self._runs_per_session[sess] == 0:
try:
sess.run(self._close_op)
except Exception as e:
@@ -256,7 +259,7 @@ class QueueRunner(object):
# Make sure we account for all terminations: normal or errors.
if not decremented:
with self._lock:
- self._runs -= 1
+ self._runs_per_session[sess] -= 1
def _close_on_stop(self, sess, cancel_op, coord):
"""Close the queue when the Coordinator requests stop.
@@ -276,19 +279,19 @@ class QueueRunner(object):
# pylint: enable=broad-except
def create_threads(self, sess, coord=None, daemon=False, start=False):
- """Create threads to run the enqueue ops.
+ """Create threads to run the enqueue ops for the given session.
This method requires a session in which the graph was launched. It creates
a list of threads, optionally starting them. There is one thread for each
op passed in `enqueue_ops`.
- The `coord` argument is an optional coordinator, that the threads will use
+ The `coord` argument is an optional coordinator that the threads will use
to terminate together and report exceptions. If a coordinator is given,
this method starts an additional thread to close the queue when the
coordinator requests a stop.
- This method may be called again as long as all threads from a previous call
- have stopped.
+ If previously created threads for the given session are still running, no
+ new threads will be created.
Args:
sess: A `Session`.
@@ -300,16 +303,16 @@ class QueueRunner(object):
Returns:
A list of threads.
-
- Raises:
- RuntimeError: If threads from a previous call to `create_threads()` are
- still running.
"""
with self._lock:
- if self._runs > 0:
- # Already started: no new threads to return.
- return []
- self._runs = len(self._enqueue_ops)
+ try:
+ if self._runs_per_session[sess] > 0:
+ # Already started: no new threads to return.
+ return []
+ except KeyError:
+ # We haven't seen this session yet.
+ pass
+ self._runs_per_session[sess] = len(self._enqueue_ops)
self._exceptions_raised = []
ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord))