aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-06-08 02:49:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 02:52:51 -0700
commit16c1d25110e48b8cecbf61ea8e15a7c9da26dd83 (patch)
treebfb37d0adfc95c7a022efdbe82e1dd68e284de98
parentc2493ed5aa9eaf375d88331c7cdb70e428614dc8 (diff)
Removes error message from queues in eager (leaves the one in queuerunners).
There's no real reason to not support queues in eager for people using them without using queue runners. PiperOrigin-RevId: 199770626
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc1
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h6
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py20
-rw-r--r--tensorflow/python/ops/data_flow_ops.py46
4 files changed, 34 insertions, 39 deletions
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 2a43a31c02..b410ea175b 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -79,6 +79,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
params.function_library = flib_;
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
+ params.cancellation_manager = &cm_;
if (stats != nullptr) {
params.track_allocations = true;
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index f78d197fd5..c41a0972b1 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -76,6 +77,11 @@ class KernelAndDevice {
const DataTypeVector& output_dtypes() { return output_dtypes_; }
private:
+ // TODO(apassos) Consider a shared cancellation manager. Note that this
+ // cancellation manager is not useful to actually cancel anything, and is
+ // provided here only for the few kernels which can't handle one being
+ // missing.
+ CancellationManager cm_;
std::unique_ptr<OpKernel> kernel_;
Device* device_;
FunctionLibraryRuntime* flib_;
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index ce73e7ad3e..14a336c688 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -125,12 +126,21 @@ class FIFOQueueTest(test.TestCase):
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
self.assertEqual(4, q.size().eval())
+ @test_util.run_in_graph_and_eager_modes()
def testMultipleDequeues(self):
- with self.test_session() as session:
- q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
- q.enqueue_many([[1, 2, 3]]).run()
- a, b, c = session.run([q.dequeue(), q.dequeue(), q.dequeue()])
- self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue_many([[1, 2, 3]]))
+ a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+ self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testQueuesDontShare(self):
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue(1))
+ q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q2.enqueue(2))
+ self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+ self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
with self.test_session():
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 62c5adc385..abf597ca55 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
@@ -129,11 +130,6 @@ class QueueBase(object):
@{tf.RandomShuffleQueue} for concrete
implementations of this class, and instructions on how to create
them.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self, dtypes, shapes, names, queue_ref):
@@ -157,12 +153,7 @@ class QueueBase(object):
Raises:
ValueError: If one of the arguments is invalid.
- RuntimeError: If eager execution is enabled.
"""
- if context.executing_eagerly():
- raise RuntimeError(
- "Queues are not supported when eager execution is enabled. "
- "Instead, please use tf.data to get data into your model.")
self._dtypes = dtypes
if shapes is not None:
if len(shapes) != len(dtypes):
@@ -179,6 +170,8 @@ class QueueBase(object):
self._queue_ref = queue_ref
if context.executing_eagerly():
self._name = context.context().scope_name
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ queue_ref, None)
else:
self._name = self._queue_ref.op.name.split("/")[-1]
@@ -605,6 +598,11 @@ class QueueBase(object):
else:
return gen_data_flow_ops.queue_size(self._queue_ref, name=name)
+def _shared_name(shared_name):
+ if context.executing_eagerly():
+ return str(ops.uid())
+ return shared_name
+
@tf_export("RandomShuffleQueue")
class RandomShuffleQueue(QueueBase):
@@ -612,11 +610,6 @@ class RandomShuffleQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -690,7 +683,7 @@ class RandomShuffleQueue(QueueBase):
min_after_dequeue=min_after_dequeue,
seed=seed1,
seed2=seed2,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -702,11 +695,6 @@ class FIFOQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -752,7 +740,7 @@ class FIFOQueue(QueueBase):
component_types=dtypes,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -767,11 +755,6 @@ class PaddingFIFOQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -831,7 +814,7 @@ class PaddingFIFOQueue(QueueBase):
component_types=dtypes,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -843,11 +826,6 @@ class PriorityQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -899,7 +877,7 @@ class PriorityQueue(QueueBase):
component_types=types,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
priority_dtypes = [_dtypes.int64] + types