aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/queue_interface.h7
-rw-r--r--tensorflow/core/kernels/fifo_queue.cc8
-rw-r--r--tensorflow/core/kernels/fifo_queue.h1
-rw-r--r--tensorflow/core/kernels/padding_fifo_queue.cc8
-rw-r--r--tensorflow/core/kernels/padding_fifo_queue.h1
-rw-r--r--tensorflow/core/kernels/queue_base.h25
-rw-r--r--tensorflow/core/kernels/queue_ops.cc80
-rw-r--r--tensorflow/core/kernels/random_shuffle_queue_op.cc9
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc39
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/kernel_tests/random_shuffle_queue_test.py8
-rw-r--r--tensorflow/python/ops/data_flow_grad.py1
-rw-r--r--tensorflow/python/ops/data_flow_ops.py45
13 files changed, 220 insertions, 13 deletions
diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h
index fd9a19c8b2..dde1f700ea 100644
--- a/tensorflow/core/framework/queue_interface.h
+++ b/tensorflow/core/framework/queue_interface.h
@@ -54,8 +54,13 @@ class QueueInterface : public ResourceBase {
virtual void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) = 0;
// Same as above, but the stashed function object will attempt to dequeue
- // num_elements items.
+ // num_elements items. If allow_small_batch is true, and the Queue is
+ // closed but at least 1 element is available, there is no blocking
+ // and between 1 and num_elements items are immediately returned.
+ // If the queue does not support the allow_small_batch flag will
+ // return an Unimplemented error.
virtual void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) = 0;
// Signals that no more elements will be enqueued, and optionally
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
index d7102e0c98..d58f750a67 100644
--- a/tensorflow/core/kernels/fifo_queue.cc
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -190,7 +190,15 @@ void FIFOQueue::TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) {
}
void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) {
+ if (allow_small_batch) {
+ ctx->SetStatus(
+ errors::Unimplemented("Dequeue: Queue does not support small batches"));
+ callback(Tuple());
+ return;
+ }
+
if (!specified_shapes()) {
ctx->SetStatus(
errors::InvalidArgument("FIFOQueue's DequeueMany requires the "
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index 1804f7cd26..c2c8cd57be 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -44,6 +44,7 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
DoneCallback callback) override;
void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc
index eba4677b6d..ec52dc8ce6 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.cc
+++ b/tensorflow/core/kernels/padding_fifo_queue.cc
@@ -65,7 +65,15 @@ Status PaddingFIFOQueue::GetElementComponent(
}
void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) {
+ if (allow_small_batch) {
+ ctx->SetStatus(
+ errors::Unimplemented("Dequeue: Queue does not support small batches"));
+ callback(Tuple());
+ return;
+ }
+
if (num_elements == 0) {
Tuple tuple;
tuple.reserve(num_components());
diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h
index 23b8b2b198..eb0937a096 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.h
+++ b/tensorflow/core/kernels/padding_fifo_queue.h
@@ -43,6 +43,7 @@ class PaddingFIFOQueue : public FIFOQueue {
// Implementations of QueueInterface methods --------------------------------
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 078c405902..83bffbe868 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -31,6 +31,10 @@ limitations under the License.
namespace tensorflow {
+namespace barrier {
+class Barrier;
+} // namespace barrier
+
// Functionality common to asynchronous QueueInterface implementations.
class QueueBase : public QueueInterface {
public:
@@ -65,6 +69,19 @@ class QueueBase : public QueueInterface {
int32 capacity() const { return capacity_; }
+ bool closed() {
+ mutex_lock lock(mu_);
+ return closed_;
+ }
+
+ // Copies the index^th slice (in the first dimension) of parent into element.
+ static Status CopySliceToElement(const Tensor& parent, Tensor* element,
+ int64 index);
+
+ // Copies element into the index^th slice (in the first dimension) of parent.
+ static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
+ int64 index);
+
protected:
enum Action { kEnqueue, kDequeue };
enum RunResult { kNoProgress, kProgress, kComplete };
@@ -98,14 +115,6 @@ class QueueBase : public QueueInterface {
return shape;
}
- // Copies the index^th slice (in the first dimension) of parent into element.
- static Status CopySliceToElement(const Tensor& parent, Tensor* element,
- int64 index);
-
- // Copies element into the index^th slice (in the first dimension) of parent.
- static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
- int64 index);
-
void Cancel(Action action, CancellationManager* cancellation_manager,
CancellationToken token);
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index ac2a1d6e64..1f2952e194 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -226,7 +226,8 @@ class DequeueManyOp : public QueueAccessOpKernel {
callback);
queue->TryDequeueMany(
- num_elements, ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
+ num_elements, ctx, false /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
@@ -251,6 +252,83 @@ class DequeueManyOp : public QueueAccessOpKernel {
REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
DequeueManyOp);
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed. While the DequeueMany op will return if there
+// an error when there are less than num_elements elements left in the
+// closed queue, this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue), and will not block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch. If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes. In this case,
+// the Queue does not block when closed.
+class DequeueUpToOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueUpToOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(
+ ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueUpToOp must request a positive number "
+ "of elements"),
+ callback);
+
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+
+ queue->TryDequeueMany(
+ num_elements, ctx, true /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components),
+ callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+ }
+
+ ~DequeueUpToOp() override {}
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
+ DequeueUpToOp);
+
// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
index 1c96d6234c..a002328869 100644
--- a/tensorflow/core/kernels/random_shuffle_queue_op.cc
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -53,6 +53,7 @@ class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > {
DoneCallback callback) override;
void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
@@ -256,7 +257,15 @@ void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
}
void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) {
+ if (allow_small_batch) {
+ ctx->SetStatus(
+ errors::Unimplemented("Dequeue: Queue does not support small batches"));
+ callback(Tuple());
+ return;
+ }
+
if (!specified_shapes()) {
ctx->SetStatus(
errors::InvalidArgument("RandomShuffleQueue's DequeueMany requires the "
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index cef74ca8ac..4d24339cc3 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -285,6 +285,9 @@ REGISTER_OP("QueueDequeueMany")
.Doc(R"doc(
Dequeues n tuples of one or more tensors from the given queue.
+If the queue is closed and there are fewer than n elements, then an
+OutOfRange error is returned.
+
This operation concatenates queue-element component tensors along the
0th dimension to make a single component tensor. All of the components
in the dequeued tuple will have size n in the 0th dimension.
@@ -305,6 +308,42 @@ timeout_ms: If the queue has fewer than n elements, this operation
Note: This option is not supported yet.
)doc");
+REGISTER_OP("QueueDequeueUpTo")
+ .Input("handle: Ref(string)")
+ .Input("n: int32")
+ .Output("components: component_types")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Dequeues n tuples of one or more tensors from the given queue.
+
+This operation is not supported by all queues. If a queue does not support
+DequeueUpTo, then an Unimplemented error is returned.
+
+If the queue is closed and there are more than 0 but less than n elements
+remaining, then instead of returning an OutOfRange error like
+QueueDequeueMany, the remaining elements are returned immediately. If the queue
+is closed and there are 0 elements left in the queue, then an OutOfRange
+error is returned just like in QueueDequeueMany. Otherwise the behavior
+is identical to QueueDequeueMany:
+
+This operation concatenates queue-element component tensors along the
+0th dimension to make a single component tensor. All of the components
+in the dequeued tuple will have size n in the 0th dimension.
+
+This operation has k outputs, where k is the number of components in
+the tuples stored in the given queue, and output i is the ith
+component of the dequeued tuple.
+
+handle: The handle to a queue.
+n: The number of tuples to dequeue.
+components: One or more tensors that were dequeued as a tuple.
+component_types: The type of each component in a tuple.
+timeout_ms: If the queue has fewer than n elements, this operation
+ will block for up to timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
REGISTER_OP("QueueClose")
.Input("handle: Ref(string)")
.Attr("cancel_pending_enqueues: bool = false")
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b8c9d2699b..216fd150e4 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -556,6 +556,7 @@ tf_gen_op_wrapper_py(
"QueueClose",
"QueueDequeue",
"QueueDequeueMany",
+ "QueueDequeueUpTo",
"QueueEnqueue",
"QueueEnqueueMany",
"QueueSize",
diff --git a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py
index 2383f88836..7e8229cd3e 100644
--- a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py
+++ b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py
@@ -1067,6 +1067,14 @@ class RandomShuffleQueueTest(tf.test.TestCase):
thread.join()
self.assertItemsEqual(elem, results)
+ def testDequeueUpToFails(self):
+ with self.test_session():
+ q = tf.RandomShuffleQueue(10, 0, tf.float32, shapes=())
+ dequeued_t = q.dequeue_up_to(0)
+ with self.assertRaisesOpError(
+ r"Dequeue: Queue does not support small batches"):
+ dequeued_t.eval()
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
index dedecaa375..1d7ed93354 100644
--- a/tensorflow/python/ops/data_flow_grad.py
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -69,6 +69,7 @@ ops.NoGradient("QueueEnqueue")
ops.NoGradient("QueueEnqueueMany")
ops.NoGradient("QueueDequeue")
ops.NoGradient("QueueDequeueMany")
+ops.NoGradient("QueueDequeueUpTo")
ops.NoGradient("QueueClose")
ops.NoGradient("QueueSize")
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 3f72ccf5cd..78826c998f 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -285,8 +285,8 @@ class QueueBase(object):
the 0th dimension to make a single component tensor. All of the
components in the dequeued tuple will have size `n` in the 0th dimension.
- If the queue contains fewer than `n` elements when this operation
- executes, it will block until `n` elements have been dequeued.
+ If the queue is closed and there are less than `n` elements left, then an
+ `OutOfRange` exception is raised.
Args:
n: A scalar `Tensor` containing the number of elements to dequeue.
@@ -299,7 +299,7 @@ class QueueBase(object):
name = "%s_DequeueMany" % self._name
ret = gen_data_flow_ops._queue_dequeue_many(
- self._queue_ref, n, self._dtypes, name=name)
+ self._queue_ref, n=n, component_types=self._dtypes, name=name)
# NOTE(mrry): Not using a shape function because we need access to
# the Queue object.
@@ -310,6 +310,44 @@ class QueueBase(object):
return ret if len(ret) != 1 else ret[0]
+ def dequeue_up_to(self, n, name=None):
+ """Dequeues and concatenates `n` elements from this queue.
+
+ **Note** This operation is not supported by all queues. If a queue does not
+ support DequeueUpTo, then an Unimplemented exception is raised.
+
+ This operation concatenates queue-element component tensors along the
+ 0th dimension to make a single component tensor. All of the components
+ in the dequeued tuple will have size `n` in the 0th dimension.
+
+ If the queue is closed and there are more than `0` but less than `n`
+ elements remaining, then instead of raising an `OutOfRange` exception like
+ `dequeue_many`, the remaining elements are returned immediately.
+ If the queue is closed and there are `0` elements left in the queue, then
+ an `OutOfRange` exception is raised just like in `dequeue_many`.
+ Otherwise the behavior is identical to `dequeue_many`:
+
+ Args:
+ n: A scalar `Tensor` containing the number of elements to dequeue.
+ name: A name for the operation (optional).
+
+ Returns:
+ The tuple of concatenated tensors that was dequeued.
+ """
+ if name is None:
+ name = "%s_DequeueUpTo" % self._name
+
+ ret = gen_data_flow_ops._queue_dequeue_up_to(
+ self._queue_ref, n=n, component_types=self._dtypes, name=name)
+
+ # NOTE(mrry): Not using a shape function because we need access to
+ # the Queue object.
+ op = ret[0].op
+ for output, shape in zip(op.values(), self._shapes):
+ output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
+
+ return ret if len(ret) != 1 else ret[0]
+
def close(self, cancel_pending_enqueues=False, name=None):
"""Closes this queue.
@@ -561,6 +599,7 @@ def _ScalarToVoidShape(op):
# Queue class to provide shape information.
ops.RegisterShape("QueueDequeue")(common_shapes.unknown_shape)
ops.RegisterShape("QueueDequeueMany")(common_shapes.unknown_shape)
+ops.RegisterShape("QueueDequeueUpTo")(common_shapes.unknown_shape)
ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape)
ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape)
ops.RegisterShape("QueueClose")(_ScalarToVoidShape)