diff options
-rw-r--r-- | tensorflow/core/framework/queue_interface.h | 7 | ||||
-rw-r--r-- | tensorflow/core/kernels/fifo_queue.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/fifo_queue.h | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/padding_fifo_queue.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/padding_fifo_queue.h | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/queue_base.h | 25 | ||||
-rw-r--r-- | tensorflow/core/kernels/queue_ops.cc | 80 | ||||
-rw-r--r-- | tensorflow/core/kernels/random_shuffle_queue_op.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/ops/data_flow_ops.cc | 39 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/random_shuffle_queue_test.py | 8 | ||||
-rw-r--r-- | tensorflow/python/ops/data_flow_grad.py | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/data_flow_ops.py | 45 |
13 files changed, 220 insertions, 13 deletions
diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h index fd9a19c8b2..dde1f700ea 100644 --- a/tensorflow/core/framework/queue_interface.h +++ b/tensorflow/core/framework/queue_interface.h @@ -54,8 +54,13 @@ class QueueInterface : public ResourceBase { virtual void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) = 0; // Same as above, but the stashed function object will attempt to dequeue - // num_elements items. + // num_elements items. If allow_small_batch is true, and the Queue is + // closed but at least 1 element is available, there is no blocking + // and between 1 and num_elements items are immediately returned. + // If the queue does not support the allow_small_batch flag will + // return an Unimplemented error. virtual void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) = 0; // Signals that no more elements will be enqueued, and optionally diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index d7102e0c98..d58f750a67 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -190,7 +190,15 @@ void FIFOQueue::TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) { } void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) { + if (allow_small_batch) { + ctx->SetStatus( + errors::Unimplemented("Dequeue: Queue does not support small batches")); + callback(Tuple()); + return; + } + if (!specified_shapes()) { ctx->SetStatus( errors::InvalidArgument("FIFOQueue's DequeueMany requires the " diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h index 1804f7cd26..c2c8cd57be 100644 --- a/tensorflow/core/kernels/fifo_queue.h +++ b/tensorflow/core/kernels/fifo_queue.h @@ -44,6 +44,7 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > { DoneCallback callback) override; void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override; void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) override; Status MatchesNodeDef(const NodeDef& node_def) override; diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index eba4677b6d..ec52dc8ce6 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -65,7 +65,15 @@ Status PaddingFIFOQueue::GetElementComponent( } void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) { + if (allow_small_batch) { + ctx->SetStatus( + errors::Unimplemented("Dequeue: Queue does not support small batches")); + callback(Tuple()); + return; + } + if (num_elements == 0) { Tuple tuple; tuple.reserve(num_components()); diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h index 23b8b2b198..eb0937a096 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.h +++ b/tensorflow/core/kernels/padding_fifo_queue.h @@ -43,6 +43,7 @@ class PaddingFIFOQueue : public FIFOQueue { // Implementations of QueueInterface methods -------------------------------- void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) override; Status MatchesNodeDef(const NodeDef& node_def) override; diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h index 078c405902..83bffbe868 100644 --- a/tensorflow/core/kernels/queue_base.h +++ b/tensorflow/core/kernels/queue_base.h @@ -31,6 +31,10 @@ limitations under the License. namespace tensorflow { +namespace barrier { +class Barrier; +} // namespace barrier + // Functionality common to asynchronous QueueInterface implementations. class QueueBase : public QueueInterface { public: @@ -65,6 +69,19 @@ class QueueBase : public QueueInterface { int32 capacity() const { return capacity_; } + bool closed() { + mutex_lock lock(mu_); + return closed_; + } + + // Copies the index^th slice (in the first dimension) of parent into element. + static Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64 index); + + // Copies element into the index^th slice (in the first dimension) of parent. + static Status CopyElementToSlice(const Tensor& element, Tensor* parent, + int64 index); + protected: enum Action { kEnqueue, kDequeue }; enum RunResult { kNoProgress, kProgress, kComplete }; @@ -98,14 +115,6 @@ class QueueBase : public QueueInterface { return shape; } - // Copies the index^th slice (in the first dimension) of parent into element. - static Status CopySliceToElement(const Tensor& parent, Tensor* element, - int64 index); - - // Copies element into the index^th slice (in the first dimension) of parent. - static Status CopyElementToSlice(const Tensor& element, Tensor* parent, - int64 index); - void Cancel(Action action, CancellationManager* cancellation_manager, CancellationToken token); diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index ac2a1d6e64..1f2952e194 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -226,7 +226,8 @@ class DequeueManyOp : public QueueAccessOpKernel { callback); queue->TryDequeueMany( - num_elements, ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { + num_elements, ctx, false /* allow_small_batch */, + [ctx, callback](const QueueInterface::Tuple& tuple) { if (!ctx->status().ok()) { callback(); return; @@ -251,6 +252,83 @@ class DequeueManyOp : public QueueAccessOpKernel { REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU), DequeueManyOp); +// Defines a DequeueUpToOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The difference between this op and DequeueMany is the handling when +// the Queue is closed. While the DequeueMany op will return if there +// an error when there are less than num_elements elements left in the +// closed queue, this op will return between 1 and +// min(num_elements, elements_remaining_in_queue), and will not block. +// If there are no elements left, then the standard DequeueMany error +// is returned. +// +// This op only works if the underlying Queue implementation accepts +// the allow_small_batch = true parameter to TryDequeueMany. +// If it does not, an errors::Unimplemented exception is returned. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +// +// The op has one attribute: allow_small_batch. If the Queue supports +// it, setting this to true causes the queue to return smaller +// (possibly zero length) batches when it is closed, up to however +// many elements are available when the op executes. In this case, +// the Queue does not block when closed. +class DequeueUpToOp : public QueueAccessOpKernel { + public: + explicit DequeueUpToOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + const Tensor& Tnum_elements = ctx->input(1); + int32 num_elements = Tnum_elements.flat<int32>()(0); + + OP_REQUIRES_ASYNC( + ctx, num_elements >= 0, + errors::InvalidArgument("DequeueUpToOp must request a positive number " + "of elements"), + callback); + + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + + queue->TryDequeueMany( + num_elements, ctx, true /* allow_small_batch */, + [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), + callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); + } + + ~DequeueUpToOp() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU), + DequeueUpToOp); + // Defines a QueueCloseOp, which closes the given Queue. Closing a // Queue signals that no more elements will be enqueued in it. // diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index 1c96d6234c..a002328869 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -53,6 +53,7 @@ class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > { DoneCallback callback) override; void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override; void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) override; Status MatchesNodeDef(const NodeDef& node_def) override; @@ -256,7 +257,15 @@ void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx, } void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) { + if (allow_small_batch) { + ctx->SetStatus( + errors::Unimplemented("Dequeue: Queue does not support small batches")); + callback(Tuple()); + return; + } + if (!specified_shapes()) { ctx->SetStatus( errors::InvalidArgument("RandomShuffleQueue's DequeueMany requires the " diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index cef74ca8ac..4d24339cc3 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -285,6 +285,9 @@ REGISTER_OP("QueueDequeueMany") .Doc(R"doc( Dequeues n tuples of one or more tensors from the given queue. +If the queue is closed and there are fewer than n elements, then an +OutOfRange error is returned. + This operation concatenates queue-element component tensors along the 0th dimension to make a single component tensor. All of the components in the dequeued tuple will have size n in the 0th dimension. @@ -305,6 +308,42 @@ timeout_ms: If the queue has fewer than n elements, this operation Note: This option is not supported yet. )doc"); +REGISTER_OP("QueueDequeueUpTo") + .Input("handle: Ref(string)") + .Input("n: int32") + .Output("components: component_types") + .Attr("component_types: list(type) >= 1") + .Attr("timeout_ms: int = -1") + .Doc(R"doc( +Dequeues n tuples of one or more tensors from the given queue. + +This operation is not supported by all queues. If a queue does not support +DequeueUpTo, then an Unimplemented error is returned. + +If the queue is closed and there are more than 0 but less than n elements +remaining, then instead of returning an OutOfRange error like +QueueDequeueMany, the remaining elements are returned immediately. If the queue +is closed and there are 0 elements left in the queue, then an OutOfRange +error is returned just like in QueueDequeueMany. Otherwise the behavior +is identical to QueueDequeueMany: + +This operation concatenates queue-element component tensors along the +0th dimension to make a single component tensor. All of the components +in the dequeued tuple will have size n in the 0th dimension. + +This operation has k outputs, where k is the number of components in +the tuples stored in the given queue, and output i is the ith +component of the dequeued tuple. + +handle: The handle to a queue. +n: The number of tuples to dequeue. +components: One or more tensors that were dequeued as a tuple. +component_types: The type of each component in a tuple. +timeout_ms: If the queue has fewer than n elements, this operation + will block for up to timeout_ms milliseconds. + Note: This option is not supported yet. +)doc"); + REGISTER_OP("QueueClose") .Input("handle: Ref(string)") .Attr("cancel_pending_enqueues: bool = false") diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b8c9d2699b..216fd150e4 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -556,6 +556,7 @@ tf_gen_op_wrapper_py( "QueueClose", "QueueDequeue", "QueueDequeueMany", + "QueueDequeueUpTo", "QueueEnqueue", "QueueEnqueueMany", "QueueSize", diff --git a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py index 2383f88836..7e8229cd3e 100644 --- a/tensorflow/python/kernel_tests/random_shuffle_queue_test.py +++ b/tensorflow/python/kernel_tests/random_shuffle_queue_test.py @@ -1067,6 +1067,14 @@ class RandomShuffleQueueTest(tf.test.TestCase): thread.join() self.assertItemsEqual(elem, results) + def testDequeueUpToFails(self): + with self.test_session(): + q = tf.RandomShuffleQueue(10, 0, tf.float32, shapes=()) + dequeued_t = q.dequeue_up_to(0) + with self.assertRaisesOpError( + r"Dequeue: Queue does not support small batches"): + dequeued_t.eval() + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py index dedecaa375..1d7ed93354 100644 --- a/tensorflow/python/ops/data_flow_grad.py +++ b/tensorflow/python/ops/data_flow_grad.py @@ -69,6 +69,7 @@ ops.NoGradient("QueueEnqueue") ops.NoGradient("QueueEnqueueMany") ops.NoGradient("QueueDequeue") ops.NoGradient("QueueDequeueMany") +ops.NoGradient("QueueDequeueUpTo") ops.NoGradient("QueueClose") ops.NoGradient("QueueSize") diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 3f72ccf5cd..78826c998f 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -285,8 +285,8 @@ class QueueBase(object): the 0th dimension to make a single component tensor. All of the components in the dequeued tuple will have size `n` in the 0th dimension. - If the queue contains fewer than `n` elements when this operation - executes, it will block until `n` elements have been dequeued. + If the queue is closed and there are less than `n` elements left, then an + `OutOfRange` exception is raised. Args: n: A scalar `Tensor` containing the number of elements to dequeue. @@ -299,7 +299,7 @@ class QueueBase(object): name = "%s_DequeueMany" % self._name ret = gen_data_flow_ops._queue_dequeue_many( - self._queue_ref, n, self._dtypes, name=name) + self._queue_ref, n=n, component_types=self._dtypes, name=name) # NOTE(mrry): Not using a shape function because we need access to # the Queue object. @@ -310,6 +310,44 @@ class QueueBase(object): return ret if len(ret) != 1 else ret[0] + def dequeue_up_to(self, n, name=None): + """Dequeues and concatenates `n` elements from this queue. + + **Note** This operation is not supported by all queues. If a queue does not + support DequeueUpTo, then an Unimplemented exception is raised. + + This operation concatenates queue-element component tensors along the + 0th dimension to make a single component tensor. All of the components + in the dequeued tuple will have size `n` in the 0th dimension. + + If the queue is closed and there are more than `0` but less than `n` + elements remaining, then instead of raising an `OutOfRange` exception like + `dequeue_many`, the remaining elements are returned immediately. + If the queue is closed and there are `0` elements left in the queue, then + an `OutOfRange` exception is raised just like in `dequeue_many`. + Otherwise the behavior is identical to `dequeue_many`: + + Args: + n: A scalar `Tensor` containing the number of elements to dequeue. + name: A name for the operation (optional). + + Returns: + The tuple of concatenated tensors that was dequeued. + """ + if name is None: + name = "%s_DequeueUpTo" % self._name + + ret = gen_data_flow_ops._queue_dequeue_up_to( + self._queue_ref, n=n, component_types=self._dtypes, name=name) + + # NOTE(mrry): Not using a shape function because we need access to + # the Queue object. + op = ret[0].op + for output, shape in zip(op.values(), self._shapes): + output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape)) + + return ret if len(ret) != 1 else ret[0] + def close(self, cancel_pending_enqueues=False, name=None): """Closes this queue. @@ -561,6 +599,7 @@ def _ScalarToVoidShape(op): # Queue class to provide shape information. ops.RegisterShape("QueueDequeue")(common_shapes.unknown_shape) ops.RegisterShape("QueueDequeueMany")(common_shapes.unknown_shape) +ops.RegisterShape("QueueDequeueUpTo")(common_shapes.unknown_shape) ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape) ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape) ops.RegisterShape("QueueClose")(_ScalarToVoidShape) |