Diffstat (limited to 'tensorflow/core/kernels')
 tensorflow/core/kernels/fifo_queue.cc              |  8 +
 tensorflow/core/kernels/fifo_queue.h               |  1 +
 tensorflow/core/kernels/padding_fifo_queue.cc      |  8 +
 tensorflow/core/kernels/padding_fifo_queue.h       |  1 +
 tensorflow/core/kernels/queue_base.h               | 25 +-
 tensorflow/core/kernels/queue_ops.cc               | 80 +-
 tensorflow/core/kernels/random_shuffle_queue_op.cc |  9 +
 7 files changed, 123 insertions(+), 9 deletions(-)
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
index d7102e0c98..d58f750a67 100644
--- a/tensorflow/core/kernels/fifo_queue.cc
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -190,7 +190,15 @@ void FIFOQueue::TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) {
}
void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) {
+ if (allow_small_batch) {
+ ctx->SetStatus(
+ errors::Unimplemented("Dequeue: Queue does not support small batches"));
+ callback(Tuple());
+ return;
+ }
+
if (!specified_shapes()) {
ctx->SetStatus(
errors::InvalidArgument("FIFOQueue's DequeueMany requires the "
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index 1804f7cd26..c2c8cd57be 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -44,6 +44,7 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
DoneCallback callback) override;
void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
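The two hunks above thread a new allow_small_batch parameter through FIFOQueue::TryDequeueMany and reject it with errors::Unimplemented; the same guard is repeated below for PaddingFIFOQueue and RandomShuffleQueue, so the new DequeueUpToOp path only succeeds once a queue actually implements small batches. A minimal self-contained sketch of this reject-until-implemented guard, using toy stand-ins (Status, Tuple, and ToyQueue here are hypothetical, not the real TensorFlow types):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for the TensorFlow types in the diff; the names are
// hypothetical, not the real TF API.
struct Status {
  bool ok = true;
  std::string message;
};
using Tuple = std::vector<int>;  // stands in for QueueInterface::Tuple
using CallbackWithTuple = std::function<void(const Tuple&)>;

class ToyQueue {
 public:
  // Mirrors the guard added in the diff: a caller passing
  // allow_small_batch = true gets Unimplemented and an empty tuple
  // until the queue actually supports small batches.
  void TryDequeueMany(int num_elements, bool allow_small_batch,
                      Status* status, CallbackWithTuple callback) {
    if (allow_small_batch) {
      *status = {false, "Dequeue: Queue does not support small batches"};
      callback(Tuple());  // still invoke the callback, as the kernels do
      return;
    }
    // ... the normal DequeueMany path would run here ...
    callback(Tuple(num_elements, 0));
  }
};

int main() {
  ToyQueue q;
  Status s;
  q.TryDequeueMany(4, /*allow_small_batch=*/true, &s,
                   [](const Tuple& t) { std::cout << t.size() << "\n"; });
  std::cout << (s.ok ? "ok" : s.message) << "\n";  // prints the error
}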
diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc
index eba4677b6d..ec52dc8ce6 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.cc
+++ b/tensorflow/core/kernels/padding_fifo_queue.cc
@@ -65,7 +65,15 @@ Status PaddingFIFOQueue::GetElementComponent(
}
void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) {
+ if (allow_small_batch) {
+ ctx->SetStatus(
+ errors::Unimplemented("Dequeue: Queue does not support small batches"));
+ callback(Tuple());
+ return;
+ }
+
if (num_elements == 0) {
Tuple tuple;
tuple.reserve(num_components());
diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h
index 23b8b2b198..eb0937a096 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.h
+++ b/tensorflow/core/kernels/padding_fifo_queue.h
@@ -43,6 +43,7 @@ class PaddingFIFOQueue : public FIFOQueue {
// Implementations of QueueInterface methods --------------------------------
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 078c405902..83bffbe868 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -31,6 +31,10 @@ limitations under the License.
namespace tensorflow {
+namespace barrier {
+class Barrier;
+} // namespace barrier
+
// Functionality common to asynchronous QueueInterface implementations.
class QueueBase : public QueueInterface {
public:
@@ -65,6 +69,19 @@ class QueueBase : public QueueInterface {
int32 capacity() const { return capacity_; }
+ bool closed() {
+ mutex_lock lock(mu_);
+ return closed_;
+ }
+
+ // Copies the index^th slice (in the first dimension) of parent into element.
+ static Status CopySliceToElement(const Tensor& parent, Tensor* element,
+ int64 index);
+
+ // Copies element into the index^th slice (in the first dimension) of parent.
+ static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
+ int64 index);
+
protected:
enum Action { kEnqueue, kDequeue };
enum RunResult { kNoProgress, kProgress, kComplete };
@@ -98,14 +115,6 @@ class QueueBase : public QueueInterface {
return shape;
}
- // Copies the index^th slice (in the first dimension) of parent into element.
- static Status CopySliceToElement(const Tensor& parent, Tensor* element,
- int64 index);
-
- // Copies element into the index^th slice (in the first dimension) of parent.
- static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
- int64 index);
-
void Cancel(Action action, CancellationManager* cancellation_manager,
CancellationToken token);
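Apart from the barrier::Barrier forward declaration and the mutex-guarded closed() accessor, the net effect of this hunk is to move CopySliceToElement and CopyElementToSlice from the protected section into the public interface, so code outside the queue class hierarchy (such as a barrier implementation) can reuse them. A rough sketch of what the slice copy means, on a flat row-major buffer (a hypothetical helper for illustration, not the real Tensor-based implementation):

#include <cstdint>
#include <iostream>
#include <vector>

// Toy model of CopySliceToElement: 'parent' is a rows x row_size
// matrix stored row-major, and the index-th slice along the first
// dimension is simply the index-th row.
bool CopySliceToElement(const std::vector<float>& parent, int64_t rows,
                        int64_t row_size, std::vector<float>* element,
                        int64_t index) {
  if (index < 0 || index >= rows) return false;  // out-of-range slice
  element->assign(parent.begin() + index * row_size,
                  parent.begin() + (index + 1) * row_size);
  return true;
}

int main() {
  std::vector<float> parent = {1, 2, 3, 4, 5, 6};  // a 3 x 2 matrix
  std::vector<float> element;
  if (CopySliceToElement(parent, /*rows=*/3, /*row_size=*/2, &element,
                         /*index=*/1)) {
    std::cout << element[0] << " " << element[1] << "\n";  // prints: 3 4
  }
}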
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index ac2a1d6e64..1f2952e194 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -226,7 +226,8 @@ class DequeueManyOp : public QueueAccessOpKernel {
callback);
queue->TryDequeueMany(
- num_elements, ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
+ num_elements, ctx, false /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
if (!ctx->status().ok()) {
callback();
return;
@@ -251,6 +252,83 @@ class DequeueManyOp : public QueueAccessOpKernel {
REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
DequeueManyOp);
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed. While the DequeueMany op will return an error
+// when there are fewer than num_elements elements left in the closed
+// queue, this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue) elements, and will
+// not block. If there are no elements left, then the standard
+// DequeueMany error is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch. If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes. In this case,
+// the Queue does not block when closed.
+class DequeueUpToOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueUpToOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(
+ ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueUpToOp must request a positive number "
+ "of elements"),
+ callback);
+
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+
+ queue->TryDequeueMany(
+ num_elements, ctx, true /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components),
+ callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+ }
+
+ ~DequeueUpToOp() override {}
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
+ DequeueUpToOp);
+
// Defines a QueueCloseOp, which closes the given Queue. Closing a
// Queue signals that no more elements will be enqueued in it.
//
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
index 1c96d6234c..a002328869 100644
--- a/tensorflow/core/kernels/random_shuffle_queue_op.cc
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -53,6 +53,7 @@ class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > {
DoneCallback callback) override;
void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
@@ -256,7 +257,15 @@ void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
}
void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) {
+ if (allow_small_batch) {
+ ctx->SetStatus(
+ errors::Unimplemented("Dequeue: Queue does not support small batches"));
+ callback(Tuple());
+ return;
+ }
+
if (!specified_shapes()) {
ctx->SetStatus(
errors::InvalidArgument("RandomShuffleQueue's DequeueMany requires the "