aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Eugene Brevdo <ebrevdo@google.com> 2016-08-30 20:55:49 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-08-30 22:02:18 -0700
commit: 7e7e0d68ceca42900ddf378a07fff2bec6dec88d (patch)
tree: c7db31a22d7ae628030d29a781d932016371c7ff
parent: 4f7a43418513d326e98c7f84fe0835ded55d0189 (diff)
Add new QueueRunner optional argument: queue_closed_exception_types.
* This is a backwards compatible change.
* Includes some extra helper functions in the tf.errors module.
* Includes extension to the QueueRunner proto
* Includes tests that QueueRunner.{from,to}_proto are backwards compatible.

Change: 131791267
-rw-r--r-- tensorflow/core/protobuf/queue_runner.proto | 6
-rw-r--r-- tensorflow/python/framework/errors.py | 13
-rw-r--r-- tensorflow/python/training/queue_runner.py | 51
-rw-r--r-- tensorflow/python/training/queue_runner_test.py | 36
4 files changed, 100 insertions, 6 deletions
diff --git a/tensorflow/core/protobuf/queue_runner.proto b/tensorflow/core/protobuf/queue_runner.proto
index 963e12784f..05a48d0acf 100644
--- a/tensorflow/core/protobuf/queue_runner.proto
+++ b/tensorflow/core/protobuf/queue_runner.proto
@@ -6,6 +6,8 @@ option java_outer_classname = "QueueRunnerProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
+import "tensorflow/core/lib/core/error_codes.proto";
+
// Protocol buffer representing a QueueRunner.
message QueueRunnerDef {
// Queue name.
@@ -19,4 +21,8 @@ message QueueRunnerDef {
// The operation to run to cancel the queue.
string cancel_op_name = 4;
+
+ // A list of exception types considered to signal a safely closed queue
+ // if raised during enqueue operations.
+ repeated error.Code queue_closed_exception_types = 5;
}
diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py
index db21f5895c..71c4d97401 100644
--- a/tensorflow/python/framework/errors.py
+++ b/tensorflow/python/framework/errors.py
@@ -428,10 +428,21 @@ _CODE_TO_EXCEPTION_CLASS = {
DATA_LOSS: DataLossError,
}
+_EXCEPTION_CLASS_TO_CODE = dict((
+ (class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items()))
+
+
+def exception_type_from_error_code(error_code):
+ return _CODE_TO_EXCEPTION_CLASS[error_code]
+
+
+def error_code_from_exception_type(cls):
+ return _EXCEPTION_CLASS_TO_CODE[cls]
+
def _make_specific_exception(node_def, op, message, error_code):
try:
- exc_type = _CODE_TO_EXCEPTION_CLASS[error_code]
+ exc_type = exception_type_from_error_code(error_code)
return exc_type(node_def, op, message)
except KeyError:
warnings.warn("Unknown error code: %d" % error_code)
diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py
index 2ae07f0bc5..4e424c5e26 100644
--- a/tensorflow/python/training/queue_runner.py
+++ b/tensorflow/python/training/queue_runner.py
@@ -44,7 +44,8 @@ class QueueRunner(object):
"""
def __init__(self, queue=None, enqueue_ops=None, close_op=None,
- cancel_op=None, queue_runner_def=None):
+ cancel_op=None, queue_closed_exception_types=None,
+ queue_runner_def=None):
"""Create a QueueRunner.
On construction the `QueueRunner` adds an op to close the queue. That op
@@ -61,6 +62,11 @@ class QueueRunner(object):
enqueue_ops: List of enqueue ops to run in threads later.
close_op: Op to close the queue. Pending enqueue ops are preserved.
cancel_op: Op to close the queue and cancel pending enqueue ops.
+ queue_closed_exception_types: Optional tuple of Exception types that
+ indicate that the queue has been closed when raised during an enqueue
+ operation. Defaults to `(tf.errors.OutOfRangeError,)`. Another common
+ case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
+ when some of the enqueue ops may dequeue from other Queues.
queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
recreates the QueueRunner from its contents. `queue_runner_def` and the
other arguments are mutually exclusive.
@@ -75,8 +81,10 @@ class QueueRunner(object):
raise ValueError("queue_runner_def and queue are mutually exclusive.")
self._init_from_proto(queue_runner_def)
else:
- self._init_from_args(queue=queue, enqueue_ops=enqueue_ops,
- close_op=close_op, cancel_op=cancel_op)
+ self._init_from_args(
+ queue=queue, enqueue_ops=enqueue_ops,
+ close_op=close_op, cancel_op=cancel_op,
+ queue_closed_exception_types=queue_closed_exception_types)
# Protect the count of runs to wait for.
self._lock = threading.Lock()
self._runs = 0
@@ -84,7 +92,7 @@ class QueueRunner(object):
self._exceptions_raised = []
def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
- cancel_op=None):
+ cancel_op=None, queue_closed_exception_types=None):
"""Create a QueueRunner from arguments.
Args:
@@ -92,10 +100,14 @@ class QueueRunner(object):
enqueue_ops: List of enqueue ops to run in threads later.
close_op: Op to close the queue. Pending enqueue ops are preserved.
cancel_op: Op to close the queue and cancel pending enqueue ops.
+ queue_closed_exception_types: Tuple of exception types, which indicate
+ the queue has been safely closed.
Raises:
ValueError: If `queue` or `enqueue_ops` are not provided when not
restoring from `queue_runner_def`.
+ TypeError: If `queue_closed_exception_types` is provided, but is not
+ a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
"""
if not queue or not enqueue_ops:
raise ValueError("Must provide queue and enqueue_ops.")
@@ -103,6 +115,16 @@ class QueueRunner(object):
self._enqueue_ops = enqueue_ops
self._close_op = close_op
self._cancel_op = cancel_op
+ if queue_closed_exception_types is not None:
+ if (not isinstance(queue_closed_exception_types, tuple)
+ or not queue_closed_exception_types
+ or not all(issubclass(t, errors.OpError)
+ for t in queue_closed_exception_types)):
+ raise TypeError(
+ "queue_closed_exception_types, when provided, "
+ "must be a non-empty list of tf.error types, but saw: %s"
+ % queue_closed_exception_types)
+ self._queue_closed_exception_types = queue_closed_exception_types
# Close when no more will be produced, but pending enqueues should be
# preserved.
if self._close_op is None:
@@ -111,6 +133,11 @@ class QueueRunner(object):
# to unblock everything so we can cleanly exit.
if self._cancel_op is None:
self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
+ if not self._queue_closed_exception_types:
+ self._queue_closed_exception_types = (errors.OutOfRangeError,)
+ else:
+ self._queue_closed_exception_types = tuple(
+ self._queue_closed_exception_types)
def _init_from_proto(self, queue_runner_def):
"""Create a QueueRunner from `QueueRunnerDef`.
@@ -125,6 +152,13 @@ class QueueRunner(object):
in queue_runner_def.enqueue_op_name]
self._close_op = g.as_graph_element(queue_runner_def.close_op_name)
self._cancel_op = g.as_graph_element(queue_runner_def.cancel_op_name)
+ self._queue_closed_exception_types = tuple(
+ errors.exception_type_from_error_code(code)
+ for code in queue_runner_def.queue_closed_exception_types)
+ # Legacy support for old QueueRunnerDefs created before this field
+ # was added.
+ if not self._queue_closed_exception_types:
+ self._queue_closed_exception_types = (errors.OutOfRangeError,)
@property
def queue(self):
@@ -143,6 +177,10 @@ class QueueRunner(object):
return self._cancel_op
@property
+ def queue_closed_exception_types(self):
+ return self._queue_closed_exception_types
+
+ @property
def exceptions_raised(self):
"""Exceptions raised but not handled by the `QueueRunner` threads.
@@ -185,7 +223,7 @@ class QueueRunner(object):
break
try:
sess.run(enqueue_op)
- except errors.OutOfRangeError:
+ except self._queue_closed_exception_types: # pylint: disable=catching-non-exception
# This exception indicates that a queue was closed.
with self._lock:
self._runs -= 1
@@ -290,6 +328,9 @@ class QueueRunner(object):
queue_runner_def.enqueue_op_name.append(enqueue_op.name)
queue_runner_def.close_op_name = self.close_op.name
queue_runner_def.cancel_op_name = self.cancel_op.name
+ queue_runner_def.queue_closed_exception_types.extend([
+ errors.error_code_from_exception_type(cls)
+ for cls in self._queue_closed_exception_types])
return queue_runner_def
@staticmethod
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index 6a19246ff1..0b4cb8cefd 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -259,5 +259,41 @@ class QueueRunnerTest(tf.test.TestCase):
# The variable should be 3.
self.assertEqual(3, var.eval())
+ def testQueueRunnerSerializationRoundTrip(self):
+ graph = tf.Graph()
+ with graph.as_default():
+ queue = tf.FIFOQueue(10, tf.float32, name="queue")
+ enqueue_op = tf.no_op(name="enqueue")
+ close_op = tf.no_op(name="close")
+ cancel_op = tf.no_op(name="cancel")
+ qr0 = tf.train.QueueRunner(
+ queue, [enqueue_op], close_op, cancel_op,
+ queue_closed_exception_types=(
+ tf.errors.OutOfRangeError, tf.errors.CancelledError))
+ qr0_proto = tf.train.QueueRunner.to_proto(qr0)
+ qr0_recon = tf.train.QueueRunner.from_proto(qr0_proto)
+ self.assertEqual("queue", qr0_recon.queue.name)
+ self.assertEqual(1, len(qr0_recon.enqueue_ops))
+ self.assertEqual(enqueue_op, qr0_recon.enqueue_ops[0])
+ self.assertEqual(close_op, qr0_recon.close_op)
+ self.assertEqual(cancel_op, qr0_recon.cancel_op)
+ self.assertEqual(
+ (tf.errors.OutOfRangeError, tf.errors.CancelledError),
+ qr0_recon.queue_closed_exception_types)
+
+ # Assert we reconstruct an OutOfRangeError for QueueRunners
+ # created before QueueRunnerDef had a queue_closed_exception_types field.
+ del qr0_proto.queue_closed_exception_types[:]
+ qr0_legacy_recon = tf.train.QueueRunner.from_proto(qr0_proto)
+ self.assertEqual("queue", qr0_legacy_recon.queue.name)
+ self.assertEqual(1, len(qr0_legacy_recon.enqueue_ops))
+ self.assertEqual(enqueue_op, qr0_legacy_recon.enqueue_ops[0])
+ self.assertEqual(close_op, qr0_legacy_recon.close_op)
+ self.assertEqual(cancel_op, qr0_legacy_recon.cancel_op)
+ self.assertEqual(
+ (tf.errors.OutOfRangeError,),
+ qr0_legacy_recon.queue_closed_exception_types)
+
+
if __name__ == "__main__":
tf.test.main()