diff options
4 files changed, 34 insertions, 39 deletions
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 2a43a31c02..b410ea175b 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -79,6 +79,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, params.function_library = flib_; params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; + params.cancellation_manager = &cm_; if (stats != nullptr) { params.track_allocations = true; } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index f78d197fd5..c41a0972b1 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -76,6 +77,11 @@ class KernelAndDevice { const DataTypeVector& output_dtypes() { return output_dtypes_; } private: + // TODO(apassos) Consider a shared cancellation manager. Note that this + // cancellation manager is not useful to actually cancel anything, and is + // provided here only for the few kernels which can't handle one being + // missing. + CancellationManager cm_; std::unique_ptr<OpKernel> kernel_; Device* device_; FunctionLibraryRuntime* flib_; diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index ce73e7ad3e..14a336c688 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -125,12 +126,21 @@ class FIFOQueueTest(test.TestCase): q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run() self.assertEqual(4, q.size().eval()) + @test_util.run_in_graph_and_eager_modes() def testMultipleDequeues(self): - with self.test_session() as session: - q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) - q.enqueue_many([[1, 2, 3]]).run() - a, b, c = session.run([q.dequeue(), q.dequeue(), q.dequeue()]) - self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue_many([[1, 2, 3]])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + @test_util.run_in_graph_and_eager_modes() + def testQueuesDontShare(self): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) def testEnqueueDictWithoutNames(self): with self.test_session(): diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 62c5adc385..abf597ca55 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_data_flow_ops import * @@ -129,11 +130,6 @@ class QueueBase(object): @{tf.RandomShuffleQueue} for concrete implementations of this class, and instructions on how to create them. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, dtypes, shapes, names, queue_ref): @@ -157,12 +153,7 @@ class QueueBase(object): Raises: ValueError: If one of the arguments is invalid. - RuntimeError: If eager execution is enabled. """ - if context.executing_eagerly(): - raise RuntimeError( - "Queues are not supported when eager execution is enabled. " - "Instead, please use tf.data to get data into your model.") self._dtypes = dtypes if shapes is not None: if len(shapes) != len(dtypes): @@ -179,6 +170,8 @@ class QueueBase(object): self._queue_ref = queue_ref if context.executing_eagerly(): self._name = context.context().scope_name + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + queue_ref, None) else: self._name = self._queue_ref.op.name.split("/")[-1] @@ -605,6 +598,11 @@ class QueueBase(object): else: return gen_data_flow_ops.queue_size(self._queue_ref, name=name) +def _shared_name(shared_name): + if context.executing_eagerly(): + return str(ops.uid()) + return shared_name + @tf_export("RandomShuffleQueue") class RandomShuffleQueue(QueueBase): @@ -612,11 +610,6 @@ class RandomShuffleQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -690,7 +683,7 @@ class RandomShuffleQueue(QueueBase): min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -702,11 +695,6 @@ class FIFOQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -752,7 +740,7 @@ class FIFOQueue(QueueBase): component_types=dtypes, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -767,11 +755,6 @@ class PaddingFIFOQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -831,7 +814,7 @@ class PaddingFIFOQueue(QueueBase): component_types=dtypes, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) @@ -843,11 +826,6 @@ class PriorityQueue(QueueBase): See @{tf.QueueBase} for a description of the methods on this class. - - @compatibility(eager) - Queues are not compatible with eager execution. Instead, please - use `tf.data` to get data into your model. - @end_compatibility """ def __init__(self, @@ -899,7 +877,7 @@ class PriorityQueue(QueueBase): component_types=types, shapes=shapes, capacity=capacity, - shared_name=shared_name, + shared_name=_shared_name(shared_name), name=name) priority_dtypes = [_dtypes.int64] + types |