diff options
author | 2017-06-01 12:43:25 -0700 | |
---|---|---|
committer | 2017-06-01 12:46:41 -0700 | |
commit | 7ad0d0698ab443324bbe68dd5d6476111c6b229a (patch) | |
tree | 7bd46a3bd1954db524f757d700d8a62fab876e90 /tensorflow/python/training/queue_runner_impl.py | |
parent | 7106f9fac32c61af59285e6ccb0b9c623a8334c3 (diff) |
Add type error to start_queue_runners if given session is not a `tf.Session`. Due to semver, we suppress the error if a MonitoredSession is provided.
PiperOrigin-RevId: 157748375
Diffstat (limited to 'tensorflow/python/training/queue_runner_impl.py')
-rw-r--r-- | tensorflow/python/training/queue_runner_impl.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py index d713e222ae..4e58602a6f 100644 --- a/tensorflow/python/training/queue_runner_impl.py +++ b/tensorflow/python/training/queue_runner_impl.py @@ -22,6 +22,7 @@ import threading import weakref from tensorflow.core.protobuf import queue_runner_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging @@ -401,6 +402,10 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection: A `GraphKey` specifying the graph collection to get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`. + Raises: + ValueError: if `sess` is None and there isn't any default session. + TypeError: if `sess` is not a `tf.Session` object. + Returns: A list of threads. """ @@ -410,6 +415,15 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True, raise ValueError("Cannot start queue runners: No default session is " "registered. Use `with sess.as_default()` or pass an " "explicit session to tf.start_queue_runners(sess=sess)") + + if not isinstance(sess, session.SessionInterface): + # Following check is due to backward compatibility. (b/62061352) + if sess.__class__.__name__ in [ + "MonitoredSession", "SingularMonitoredSession"]: + return [] + raise TypeError("sess must be a `tf.Session` object. " + "Given class: {}".format(sess.__class__)) + with sess.graph.as_default(): threads = [] for qr in ops.get_collection(collection): |