aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-06-01 12:43:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-01 12:46:41 -0700
commit7ad0d0698ab443324bbe68dd5d6476111c6b229a (patch)
tree7bd46a3bd1954db524f757d700d8a62fab876e90 /tensorflow/python/training
parent7106f9fac32c61af59285e6ccb0b9c623a8334c3 (diff)
Add type error to start_queue_runners if given session is not a `tf.Session`. Due to semver, we suppress the error if a MonitoredSession is provided.
PiperOrigin-RevId: 157748375
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/queue_runner_impl.py14
-rw-r--r--tensorflow/python/training/queue_runner_test.py28
2 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/python/training/queue_runner_impl.py b/tensorflow/python/training/queue_runner_impl.py
index d713e222ae..4e58602a6f 100644
--- a/tensorflow/python/training/queue_runner_impl.py
+++ b/tensorflow/python/training/queue_runner_impl.py
@@ -22,6 +22,7 @@ import threading
import weakref
from tensorflow.core.protobuf import queue_runner_pb2
+from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
@@ -401,6 +402,10 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
collection: A `GraphKey` specifying the graph collection to
get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.
+ Raises:
+ ValueError: if `sess` is None and there isn't any default session.
+ TypeError: if `sess` is not a `tf.Session` object.
+
Returns:
A list of threads.
"""
@@ -410,6 +415,15 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
raise ValueError("Cannot start queue runners: No default session is "
"registered. Use `with sess.as_default()` or pass an "
"explicit session to tf.start_queue_runners(sess=sess)")
+
+ if not isinstance(sess, session.SessionInterface):
+ # Following check is due to backward compatibility. (b/62061352)
+ if sess.__class__.__name__ in [
+ "MonitoredSession", "SingularMonitoredSession"]:
+ return []
+ raise TypeError("sess must be a `tf.Session` object. "
+ "Given class: {}".format(sess.__class__))
+
with sess.graph.as_default():
threads = []
for qr in ops.get_collection(collection):
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index 5b00ac9fc3..51c0eecf46 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import queue_runner_impl
@@ -247,6 +248,33 @@ class QueueRunnerTest(test.TestCase):
# The variable should be 3.
self.assertEqual(3, var.eval())
+ def testStartQueueRunnersRaisesIfNotASession(self):
+ zero64 = constant_op.constant(0, dtype=dtypes.int64)
+ var = variables.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
+ init_op = variables.global_variables_initializer()
+ qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
+ queue_runner_impl.add_queue_runner(qr)
+ with self.test_session():
+ init_op.run()
+ with self.assertRaisesRegexp(TypeError, "tf.Session"):
+ queue_runner_impl.start_queue_runners("NotASession")
+
+ def testStartQueueRunnersIgnoresMonitoredSession(self):
+ zero64 = constant_op.constant(0, dtype=dtypes.int64)
+ var = variables.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
+ init_op = variables.global_variables_initializer()
+ qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
+ queue_runner_impl.add_queue_runner(qr)
+ with self.test_session():
+ init_op.run()
+ threads = queue_runner_impl.start_queue_runners(
+ monitored_session.MonitoredSession())
+ self.assertFalse(threads)
+
def testStartQueueRunnersNonDefaultGraph(self):
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
graph = ops.Graph()