diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2016-08-30 20:55:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-30 22:02:18 -0700 |
commit | 7e7e0d68ceca42900ddf378a07fff2bec6dec88d (patch) | |
tree | c7db31a22d7ae628030d29a781d932016371c7ff | |
parent | 4f7a43418513d326e98c7f84fe0835ded55d0189 (diff) |
Add new QueueRunner optional argument: queue_closed_exception_types.
* This is a backwards-compatible change.
* Includes some extra helper functions in the tf.errors module.
* Includes an extension to the QueueRunnerDef proto.
* Includes tests that QueueRunner.{from,to}_proto are backwards compatible.
Change: 131791267
-rw-r--r-- | tensorflow/core/protobuf/queue_runner.proto | 6 | ||||
-rw-r--r-- | tensorflow/python/framework/errors.py | 13 | ||||
-rw-r--r-- | tensorflow/python/training/queue_runner.py | 51 | ||||
-rw-r--r-- | tensorflow/python/training/queue_runner_test.py | 36 |
4 files changed, 100 insertions, 6 deletions
diff --git a/tensorflow/core/protobuf/queue_runner.proto b/tensorflow/core/protobuf/queue_runner.proto index 963e12784f..05a48d0acf 100644 --- a/tensorflow/core/protobuf/queue_runner.proto +++ b/tensorflow/core/protobuf/queue_runner.proto @@ -6,6 +6,8 @@ option java_outer_classname = "QueueRunnerProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; +import "tensorflow/core/lib/core/error_codes.proto"; + // Protocol buffer representing a QueueRunner. message QueueRunnerDef { // Queue name. @@ -19,4 +21,8 @@ message QueueRunnerDef { // The operation to run to cancel the queue. string cancel_op_name = 4; + + // A list of exception types considered to signal a safely closed queue + // if raised during enqueue operations. + repeated error.Code queue_closed_exception_types = 5; } diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py index db21f5895c..71c4d97401 100644 --- a/tensorflow/python/framework/errors.py +++ b/tensorflow/python/framework/errors.py @@ -428,10 +428,21 @@ _CODE_TO_EXCEPTION_CLASS = { DATA_LOSS: DataLossError, } +_EXCEPTION_CLASS_TO_CODE = dict(( + (class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items())) + + +def exception_type_from_error_code(error_code): + return _CODE_TO_EXCEPTION_CLASS[error_code] + + +def error_code_from_exception_type(cls): + return _EXCEPTION_CLASS_TO_CODE[cls] + def _make_specific_exception(node_def, op, message, error_code): try: - exc_type = _CODE_TO_EXCEPTION_CLASS[error_code] + exc_type = exception_type_from_error_code(error_code) return exc_type(node_def, op, message) except KeyError: warnings.warn("Unknown error code: %d" % error_code) diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py index 2ae07f0bc5..4e424c5e26 100644 --- a/tensorflow/python/training/queue_runner.py +++ b/tensorflow/python/training/queue_runner.py @@ -44,7 +44,8 @@ class QueueRunner(object): """ def 
__init__(self, queue=None, enqueue_ops=None, close_op=None, - cancel_op=None, queue_runner_def=None): + cancel_op=None, queue_closed_exception_types=None, + queue_runner_def=None): """Create a QueueRunner. On construction the `QueueRunner` adds an op to close the queue. That op @@ -61,6 +62,11 @@ class QueueRunner(object): enqueue_ops: List of enqueue ops to run in threads later. close_op: Op to close the queue. Pending enqueue ops are preserved. cancel_op: Op to close the queue and cancel pending enqueue ops. + queue_closed_exception_types: Optional tuple of Exception types that + indicate that the queue has been closed when raised during an enqueue + operation. Defaults to `(tf.errors.OutOfRangeError,)`. Another common + case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`, + when some of the enqueue ops may dequeue from other Queues. queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified, recreates the QueueRunner from its contents. `queue_runner_def` and the other arguments are mutually exclusive. @@ -75,8 +81,10 @@ class QueueRunner(object): raise ValueError("queue_runner_def and queue are mutually exclusive.") self._init_from_proto(queue_runner_def) else: - self._init_from_args(queue=queue, enqueue_ops=enqueue_ops, - close_op=close_op, cancel_op=cancel_op) + self._init_from_args( + queue=queue, enqueue_ops=enqueue_ops, + close_op=close_op, cancel_op=cancel_op, + queue_closed_exception_types=queue_closed_exception_types) # Protect the count of runs to wait for. self._lock = threading.Lock() self._runs = 0 @@ -84,7 +92,7 @@ class QueueRunner(object): self._exceptions_raised = [] def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None, - cancel_op=None): + cancel_op=None, queue_closed_exception_types=None): """Create a QueueRunner from arguments. Args: @@ -92,10 +100,14 @@ class QueueRunner(object): enqueue_ops: List of enqueue ops to run in threads later. close_op: Op to close the queue. 
Pending enqueue ops are preserved. cancel_op: Op to close the queue and cancel pending enqueue ops. + queue_closed_exception_types: Tuple of exception types, which indicate + the queue has been safely closed. Raises: ValueError: If `queue` or `enqueue_ops` are not provided when not restoring from `queue_runner_def`. + TypeError: If `queue_closed_exception_types` is provided, but is not + a non-empty tuple of error types (subclasses of `tf.errors.OpError`). """ if not queue or not enqueue_ops: raise ValueError("Must provide queue and enqueue_ops.") @@ -103,6 +115,16 @@ class QueueRunner(object): self._enqueue_ops = enqueue_ops self._close_op = close_op self._cancel_op = cancel_op + if queue_closed_exception_types is not None: + if (not isinstance(queue_closed_exception_types, tuple) + or not queue_closed_exception_types + or not all(issubclass(t, errors.OpError) + for t in queue_closed_exception_types)): + raise TypeError( + "queue_closed_exception_types, when provided, " + "must be a non-empty list of tf.error types, but saw: %s" + % queue_closed_exception_types) + self._queue_closed_exception_types = queue_closed_exception_types # Close when no more will be produced, but pending enqueues should be # preserved. if self._close_op is None: @@ -111,6 +133,11 @@ class QueueRunner(object): # to unblock everything so we can cleanly exit. if self._cancel_op is None: self._cancel_op = self._queue.close(cancel_pending_enqueues=True) + if not self._queue_closed_exception_types: + self._queue_closed_exception_types = (errors.OutOfRangeError,) + else: + self._queue_closed_exception_types = tuple( + self._queue_closed_exception_types) def _init_from_proto(self, queue_runner_def): """Create a QueueRunner from `QueueRunnerDef`. 
@@ -125,6 +152,13 @@ class QueueRunner(object): in queue_runner_def.enqueue_op_name] self._close_op = g.as_graph_element(queue_runner_def.close_op_name) self._cancel_op = g.as_graph_element(queue_runner_def.cancel_op_name) + self._queue_closed_exception_types = tuple( + errors.exception_type_from_error_code(code) + for code in queue_runner_def.queue_closed_exception_types) + # Legacy support for old QueueRunnerDefs created before this field + # was added. + if not self._queue_closed_exception_types: + self._queue_closed_exception_types = (errors.OutOfRangeError,) @property def queue(self): @@ -143,6 +177,10 @@ class QueueRunner(object): return self._cancel_op @property + def queue_closed_exception_types(self): + return self._queue_closed_exception_types + + @property def exceptions_raised(self): """Exceptions raised but not handled by the `QueueRunner` threads. @@ -185,7 +223,7 @@ class QueueRunner(object): break try: sess.run(enqueue_op) - except errors.OutOfRangeError: + except self._queue_closed_exception_types: # pylint: disable=catching-non-exception # This exception indicates that a queue was closed. with self._lock: self._runs -= 1 @@ -290,6 +328,9 @@ class QueueRunner(object): queue_runner_def.enqueue_op_name.append(enqueue_op.name) queue_runner_def.close_op_name = self.close_op.name queue_runner_def.cancel_op_name = self.cancel_op.name + queue_runner_def.queue_closed_exception_types.extend([ + errors.error_code_from_exception_type(cls) + for cls in self._queue_closed_exception_types]) return queue_runner_def @staticmethod diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py index 6a19246ff1..0b4cb8cefd 100644 --- a/tensorflow/python/training/queue_runner_test.py +++ b/tensorflow/python/training/queue_runner_test.py @@ -259,5 +259,41 @@ class QueueRunnerTest(tf.test.TestCase): # The variable should be 3. 
self.assertEqual(3, var.eval()) + def testQueueRunnerSerializationRoundTrip(self): + graph = tf.Graph() + with graph.as_default(): + queue = tf.FIFOQueue(10, tf.float32, name="queue") + enqueue_op = tf.no_op(name="enqueue") + close_op = tf.no_op(name="close") + cancel_op = tf.no_op(name="cancel") + qr0 = tf.train.QueueRunner( + queue, [enqueue_op], close_op, cancel_op, + queue_closed_exception_types=( + tf.errors.OutOfRangeError, tf.errors.CancelledError)) + qr0_proto = tf.train.QueueRunner.to_proto(qr0) + qr0_recon = tf.train.QueueRunner.from_proto(qr0_proto) + self.assertEqual("queue", qr0_recon.queue.name) + self.assertEqual(1, len(qr0_recon.enqueue_ops)) + self.assertEqual(enqueue_op, qr0_recon.enqueue_ops[0]) + self.assertEqual(close_op, qr0_recon.close_op) + self.assertEqual(cancel_op, qr0_recon.cancel_op) + self.assertEqual( + (tf.errors.OutOfRangeError, tf.errors.CancelledError), + qr0_recon.queue_closed_exception_types) + + # Assert we reconstruct an OutOfRangeError for QueueRunners + # created before QueueRunnerDef had a queue_closed_exception_types field. + del qr0_proto.queue_closed_exception_types[:] + qr0_legacy_recon = tf.train.QueueRunner.from_proto(qr0_proto) + self.assertEqual("queue", qr0_legacy_recon.queue.name) + self.assertEqual(1, len(qr0_legacy_recon.enqueue_ops)) + self.assertEqual(enqueue_op, qr0_legacy_recon.enqueue_ops[0]) + self.assertEqual(close_op, qr0_legacy_recon.close_op) + self.assertEqual(cancel_op, qr0_legacy_recon.cancel_op) + self.assertEqual( + (tf.errors.OutOfRangeError,), + qr0_legacy_recon.queue_closed_exception_types) + + if __name__ == "__main__": tf.test.main() |