aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc1
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h6
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py20
-rw-r--r--tensorflow/python/ops/data_flow_ops.py46
4 files changed, 34 insertions, 39 deletions
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 2a43a31c02..b410ea175b 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -79,6 +79,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
params.function_library = flib_;
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
+ params.cancellation_manager = &cm_;
if (stats != nullptr) {
params.track_allocations = true;
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index f78d197fd5..c41a0972b1 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -76,6 +77,11 @@ class KernelAndDevice {
const DataTypeVector& output_dtypes() { return output_dtypes_; }
private:
+ // TODO(apassos) Consider a shared cancellation manager. Note that this
+ // cancellation manager is not useful to actually cancel anything, and is
+ // provided here only for the few kernels which can't handle one being
+ // missing.
+ CancellationManager cm_;
std::unique_ptr<OpKernel> kernel_;
Device* device_;
FunctionLibraryRuntime* flib_;
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index ce73e7ad3e..14a336c688 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -125,12 +126,21 @@ class FIFOQueueTest(test.TestCase):
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
self.assertEqual(4, q.size().eval())
+ @test_util.run_in_graph_and_eager_modes()
def testMultipleDequeues(self):
- with self.test_session() as session:
- q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
- q.enqueue_many([[1, 2, 3]]).run()
- a, b, c = session.run([q.dequeue(), q.dequeue(), q.dequeue()])
- self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue_many([[1, 2, 3]]))
+ a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+ self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testQueuesDontShare(self):
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue(1))
+ q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q2.enqueue(2))
+ self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+ self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
with self.test_session():
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 62c5adc385..abf597ca55 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
@@ -129,11 +130,6 @@ class QueueBase(object):
@{tf.RandomShuffleQueue} for concrete
implementations of this class, and instructions on how to create
them.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self, dtypes, shapes, names, queue_ref):
@@ -157,12 +153,7 @@ class QueueBase(object):
Raises:
ValueError: If one of the arguments is invalid.
- RuntimeError: If eager execution is enabled.
"""
- if context.executing_eagerly():
- raise RuntimeError(
- "Queues are not supported when eager execution is enabled. "
- "Instead, please use tf.data to get data into your model.")
self._dtypes = dtypes
if shapes is not None:
if len(shapes) != len(dtypes):
@@ -179,6 +170,8 @@ class QueueBase(object):
self._queue_ref = queue_ref
if context.executing_eagerly():
self._name = context.context().scope_name
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ queue_ref, None)
else:
self._name = self._queue_ref.op.name.split("/")[-1]
@@ -605,6 +598,11 @@ class QueueBase(object):
else:
return gen_data_flow_ops.queue_size(self._queue_ref, name=name)
+def _shared_name(shared_name):
+ if context.executing_eagerly():
+ return str(ops.uid())
+ return shared_name
+
@tf_export("RandomShuffleQueue")
class RandomShuffleQueue(QueueBase):
@@ -612,11 +610,6 @@ class RandomShuffleQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -690,7 +683,7 @@ class RandomShuffleQueue(QueueBase):
min_after_dequeue=min_after_dequeue,
seed=seed1,
seed2=seed2,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -702,11 +695,6 @@ class FIFOQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -752,7 +740,7 @@ class FIFOQueue(QueueBase):
component_types=dtypes,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -767,11 +755,6 @@ class PaddingFIFOQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -831,7 +814,7 @@ class PaddingFIFOQueue(QueueBase):
component_types=dtypes,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -843,11 +826,6 @@ class PriorityQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -899,7 +877,7 @@ class PriorityQueue(QueueBase):
component_types=types,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
priority_dtypes = [_dtypes.int64] + types