diff options
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/fifo_queue.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/fifo_queue.h | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/padding_fifo_queue.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/padding_fifo_queue.h | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/queue_base.h | 25 | ||||
-rw-r--r-- | tensorflow/core/kernels/queue_ops.cc | 80 | ||||
-rw-r--r-- | tensorflow/core/kernels/random_shuffle_queue_op.cc | 9 |
7 files changed, 123 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index d7102e0c98..d58f750a67 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -190,7 +190,15 @@ void FIFOQueue::TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) { } void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) { + if (allow_small_batch) { + ctx->SetStatus( + errors::Unimplemented("Dequeue: Queue does not support small batches")); + callback(Tuple()); + return; + } + if (!specified_shapes()) { ctx->SetStatus( errors::InvalidArgument("FIFOQueue's DequeueMany requires the " diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h index 1804f7cd26..c2c8cd57be 100644 --- a/tensorflow/core/kernels/fifo_queue.h +++ b/tensorflow/core/kernels/fifo_queue.h @@ -44,6 +44,7 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > { DoneCallback callback) override; void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override; void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) override; Status MatchesNodeDef(const NodeDef& node_def) override; diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index eba4677b6d..ec52dc8ce6 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -65,7 +65,15 @@ Status PaddingFIFOQueue::GetElementComponent( } void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) { + if (allow_small_batch) { + ctx->SetStatus( + errors::Unimplemented("Dequeue: Queue does not support small batches")); + callback(Tuple()); + return; + } + if (num_elements == 0) { Tuple tuple; tuple.reserve(num_components()); diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h index 23b8b2b198..eb0937a096 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.h +++ b/tensorflow/core/kernels/padding_fifo_queue.h @@ -43,6 +43,7 @@ class PaddingFIFOQueue : public FIFOQueue { // Implementations of QueueInterface methods -------------------------------- void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) override; Status MatchesNodeDef(const NodeDef& node_def) override; diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h index 078c405902..83bffbe868 100644 --- a/tensorflow/core/kernels/queue_base.h +++ b/tensorflow/core/kernels/queue_base.h @@ -31,6 +31,10 @@ limitations under the License. namespace tensorflow { +namespace barrier { +class Barrier; +} // namespace barrier + // Functionality common to asynchronous QueueInterface implementations. class QueueBase : public QueueInterface { public: @@ -65,6 +69,19 @@ class QueueBase : public QueueInterface { int32 capacity() const { return capacity_; } + bool closed() { + mutex_lock lock(mu_); + return closed_; + } + + // Copies the index^th slice (in the first dimension) of parent into element. + static Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64 index); + + // Copies element into the index^th slice (in the first dimension) of parent. + static Status CopyElementToSlice(const Tensor& element, Tensor* parent, + int64 index); + protected: enum Action { kEnqueue, kDequeue }; enum RunResult { kNoProgress, kProgress, kComplete }; @@ -98,14 +115,6 @@ class QueueBase : public QueueInterface { return shape; } - // Copies the index^th slice (in the first dimension) of parent into element. - static Status CopySliceToElement(const Tensor& parent, Tensor* element, - int64 index); - - // Copies element into the index^th slice (in the first dimension) of parent. - static Status CopyElementToSlice(const Tensor& element, Tensor* parent, - int64 index); - void Cancel(Action action, CancellationManager* cancellation_manager, CancellationToken token); diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index ac2a1d6e64..1f2952e194 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -226,7 +226,8 @@ class DequeueManyOp : public QueueAccessOpKernel { callback); queue->TryDequeueMany( - num_elements, ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { + num_elements, ctx, false /* allow_small_batch */, + [ctx, callback](const QueueInterface::Tuple& tuple) { if (!ctx->status().ok()) { callback(); return; @@ -251,6 +252,83 @@ class DequeueManyOp : public QueueAccessOpKernel { REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU), DequeueManyOp); +// Defines a DequeueUpToOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The difference between this op and DequeueMany is the handling when +// the Queue is closed. While the DequeueMany op will return if there +// an error when there are less than num_elements elements left in the +// closed queue, this op will return between 1 and +// min(num_elements, elements_remaining_in_queue), and will not block. +// If there are no elements left, then the standard DequeueMany error +// is returned. +// +// This op only works if the underlying Queue implementation accepts +// the allow_small_batch = true parameter to TryDequeueMany. +// If it does not, an errors::Unimplemented exception is returned. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +// +// The op has one attribute: allow_small_batch. If the Queue supports +// it, setting this to true causes the queue to return smaller +// (possibly zero length) batches when it is closed, up to however +// many elements are available when the op executes. In this case, +// the Queue does not block when closed. +class DequeueUpToOp : public QueueAccessOpKernel { + public: + explicit DequeueUpToOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + const Tensor& Tnum_elements = ctx->input(1); + int32 num_elements = Tnum_elements.flat<int32>()(0); + + OP_REQUIRES_ASYNC( + ctx, num_elements >= 0, + errors::InvalidArgument("DequeueUpToOp must request a positive number " + "of elements"), + callback); + + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + + queue->TryDequeueMany( + num_elements, ctx, true /* allow_small_batch */, + [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), + callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); + } + + ~DequeueUpToOp() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU), + DequeueUpToOp); + // Defines a QueueCloseOp, which closes the given Queue. Closing a // Queue signals that no more elements will be enqueued in it. // diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index 1c96d6234c..a002328869 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -53,6 +53,7 @@ class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > { DoneCallback callback) override; void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override; void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) override; Status MatchesNodeDef(const NodeDef& node_def) override; @@ -256,7 +257,15 @@ void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx, } void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, CallbackWithTuple callback) { + if (allow_small_batch) { + ctx->SetStatus( + errors::Unimplemented("Dequeue: Queue does not support small batches")); + callback(Tuple()); + return; + } + if (!specified_shapes()) { ctx->SetStatus( errors::InvalidArgument("RandomShuffleQueue's DequeueMany requires the " |