diff options
author | Alexandre Passos <apassos@google.com> | 2018-06-08 02:49:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-08 02:52:51 -0700 |
commit | 16c1d25110e48b8cecbf61ea8e15a7c9da26dd83 (patch) | |
tree | bfb37d0adfc95c7a022efdbe82e1dd68e284de98 | |
parent | c2493ed5aa9eaf375d88331c7cdb70e428614dc8 (diff) |
Removes error message from queues in eager (leaves the one in queuerunners).
There's no real reason to not support queues in eager for people using them
without using queue runners.
PiperOrigin-RevId: 199770626
4 files changed, 34 insertions, 39 deletions
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 2a43a31c02..b410ea175b 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -79,6 +79,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, params.function_library = flib_; params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; + params.cancellation_manager = &cm_; if (stats != nullptr) { params.track_allocations = true; } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index f78d197fd5..c41a0972b1 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -76,6 +77,11 @@ class KernelAndDevice { const DataTypeVector& output_dtypes() { return output_dtypes_; } private: + // TODO(apassos) Consider a shared cancellation manager. Note that this + // cancellation manager is not useful to actually cancel anything, and is + // provided here only for the few kernels which can't handle one being + // missing. + CancellationManager cm_; std::unique_ptr<OpKernel> kernel_; Device* device_; FunctionLibraryRuntime* flib_; diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index ce73e7ad3e..14a336c688 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -125,12 +126,21 @@ class FIFOQueueTest(test.TestCase): q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run() self.assertEqual(4, q.size().eval()) + @test_util.run_in_graph_and_eager_modes() def testMultipleDequeues(self): - with self.test_session() as session: - q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) - q.enqueue_many([[1, 2, 3]]).run() - a, b, c = session.run([q.dequeue(), q.dequeue(), q.dequeue()]) - self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue_many([[1, 2, 3]])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + @test_util.run_in_graph_and_eager_modes() + def testQueuesDontShare(self): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) def testEnqueueDictWithoutNames(self): with self.test_session(): diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 62c5adc385..abf597ca55 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * @@ -129,11 +130,6 @@ class QueueBase(object): @{tf.RandomShuffleQueue} for concrete implementations of this class, and instructions on how to create them. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, dtypes, shapes, names, queue_ref): @@ -157,12 +153,7 @@ class QueueBase(object): Raises: ValueError: If one of the arguments is invalid. - RuntimeError: If eager execution is enabled. """ - if context.executing_eagerly(): - raise RuntimeError( - "Queues are not supported when eager execution is enabled. " - "Instead, please use tf.data to get data into your model.") self._dtypes = dtypes if shapes is not None: if len(shapes) != len(dtypes): @@ -179,6 +170,8 @@ class QueueBase(object): self._queue_ref = queue_ref if context.executing_eagerly(): self._name = context.context().scope_name + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + queue_ref, None) else: self._name = self._queue_ref.op.name.split("/")[-1] @@ -605,6 +598,11 @@ class QueueBase(object): else: return gen_data_flow_ops.queue_size(self._queue_ref, name=name) +def _shared_name(shared_name): + if context.executing_eagerly(): + return str(ops.uid()) + return shared_name + @tf_export("RandomShuffleQueue") class RandomShuffleQueue(QueueBase): @@ -612,11 +610,6 @@ class RandomShuffleQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -690,7 +683,7 @@ class RandomShuffleQueue(QueueBase): min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -702,11 +695,6 @@ class FIFOQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -752,7 +740,7 @@ class FIFOQueue(QueueBase): component_types=dtypes, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -767,11 +755,6 @@ class PaddingFIFOQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -831,7 +814,7 @@ class PaddingFIFOQueue(QueueBase): component_types=dtypes, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -843,11 +826,6 @@ class PriorityQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -899,7 +877,7 @@ class PriorityQueue(QueueBase): component_types=types, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) priority_dtypes = [_dtypes.int64] + types |