aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/queue_runner_impl.py
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-06-01 12:43:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-01 12:46:41 -0700
commit7ad0d0698ab443324bbe68dd5d6476111c6b229a (patch)
tree7bd46a3bd1954db524f757d700d8a62fab876e90 /tensorflow/python/training/queue_runner_impl.py
parent7106f9fac32c61af59285e6ccb0b9c623a8334c3 (diff)
Add type error to start_queue_runners if given session is not a `tf.Session`. Due to semver, we suppress the error if a MonitoredSession is provided.
PiperOrigin-RevId: 157748375
Diffstat (limited to 'tensorflow/python/training/queue_runner_impl.py')
-rw-r--r--tensorflow/python/training/queue_runner_impl.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py
index d713e222ae..4e58602a6f 100644
--- a/tensorflow/python/training/queue_runner_impl.py
+++ b/tensorflow/python/training/queue_runner_impl.py
@@ -22,6 +22,7 @@ import threading
import weakref
from tensorflow.core.protobuf import queue_runner_pb2
+from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
@@ -401,6 +402,10 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
collection: A `GraphKey` specifying the graph collection to
get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.
+ Raises:
+ ValueError: if `sess` is None and there isn't any default session.
+ TypeError: if `sess` is not a `tf.Session` object.
+
Returns:
A list of threads.
"""
@@ -410,6 +415,15 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
raise ValueError("Cannot start queue runners: No default session is "
"registered. Use `with sess.as_default()` or pass an "
"explicit session to tf.start_queue_runners(sess=sess)")
+
+ if not isinstance(sess, session.SessionInterface):
+ # Following check is due to backward compatibility. (b/62061352)
+ if sess.__class__.__name__ in [
+ "MonitoredSession", "SingularMonitoredSession"]:
+ return []
+ raise TypeError("sess must be a `tf.Session` object. "
+ "Given class: {}".format(sess.__class__))
+
with sess.graph.as_default():
threads = []
for qr in ops.get_collection(collection):