diff options
Diffstat (limited to 'tensorflow/core/kernels/queue_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/queue_ops.cc | 288 |
1 files changed, 288 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc new file mode 100644 index 0000000000..c70dc76777 --- /dev/null +++ b/tensorflow/core/kernels/queue_ops.cc @@ -0,0 +1,288 @@ +// See docs in ../ops/data_flow_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class QueueOpKernel : public AsyncOpKernel { + public: + explicit QueueOpKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { + QueueInterface* queue; + OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), + callback); + ComputeAsync(ctx, queue, [callback, queue]() { + queue->Unref(); + callback(); + }); + } + + protected: + virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) = 0; +}; + +class QueueAccessOpKernel : public QueueOpKernel { + public: + explicit QueueAccessOpKernel(OpKernelConstruction* context) + : QueueOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_)); + // TODO(keveman): Enable timeout. + OP_REQUIRES(context, timeout_ == -1, + errors::InvalidArgument("Timeout not supported yet.")); + } + + protected: + int64 timeout_; +}; + +// Defines an EnqueueOp, the execution of which enqueues a tuple of +// tensors in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +class EnqueueOp : public QueueAccessOpKernel { + public: + explicit EnqueueOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + DataTypeVector expected_inputs = {DT_STRING_REF}; + for (DataType dt : queue->component_dtypes()) { + expected_inputs.push_back(dt); + } + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), + callback); + + QueueInterface::Tuple tuple; + OpInputList components; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), + callback); + for (const Tensor& Tcomponent : components) { + tuple.push_back(Tcomponent); + } + + OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback); + queue->TryEnqueue(tuple, ctx, callback); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp); + +// Defines an EnqueueManyOp, the execution of which slices each +// component of a tuple of tensors along the 0th dimension, and +// enqueues tuples of slices in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +// +// N.B. All tuple components must have the same size in the 0th +// dimension. +class EnqueueManyOp : public QueueAccessOpKernel { + public: + explicit EnqueueManyOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + DataTypeVector expected_inputs = {DT_STRING_REF}; + for (DataType dt : queue->component_dtypes()) { + expected_inputs.push_back(dt); + } + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + + QueueInterface::Tuple tuple; + OpInputList components; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), + callback); + for (const Tensor& Tcomponent : components) { + tuple.push_back(Tcomponent); + } + + OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback); + queue->TryEnqueueMany(tuple, ctx, callback); + } + + ~EnqueueManyOp() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU), + EnqueueManyOp); + +// Defines a DequeueOp, the execution of which dequeues a tuple of +// tensors from the given Queue. +// +// The op has one input, which is the handle of the appropriate +// Queue. The op has k outputs, where k is the number of components in +// the tuples stored in the given Queue, and output i is the ith +// component of the dequeued tuple. +class DequeueOp : public QueueAccessOpKernel { + public: + explicit DequeueOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), + callback); + + queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); + } + + ~DequeueOp() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp); + +// Defines a DequeueManyOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +class DequeueManyOp : public QueueAccessOpKernel { + public: + explicit DequeueManyOp(OpKernelConstruction* context) + : QueueAccessOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + const Tensor& Tnum_elements = ctx->input(1); + int32 num_elements = Tnum_elements.flat<int32>()(0); + + OP_REQUIRES_ASYNC( + ctx, num_elements >= 0, + errors::InvalidArgument("DequeueManyOp must request a positive number " + "of elements"), + callback); + + OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, + queue->component_dtypes()), + callback); + + queue->TryDequeueMany( + num_elements, ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { + if (!ctx->status().ok()) { + callback(); + return; + } + OpOutputList output_components; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->output_list("components", &output_components), + callback); + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_components.set(i, tuple[i]); + } + callback(); + }); + } + + ~DequeueManyOp() override {} + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU), + DequeueManyOp); + +// Defines a QueueCloseOp, which closes the given Queue. Closing a +// Queue signals that no more elements will be enqueued in it. +// +// The op has one input, which is the handle of the appropriate Queue. +class QueueCloseOp : public QueueOpKernel { + public: + explicit QueueCloseOp(OpKernelConstruction* context) + : QueueOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues", + &cancel_pending_enqueues_)); + } + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + queue->Close(ctx, cancel_pending_enqueues_, callback); + } + + private: + bool cancel_pending_enqueues_; + TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp); + +// Defines a QueueSizeOp, which computes the number of elements in the +// given Queue, and emits it as an output tensor. +// +// The op has one input, which is the handle of the appropriate Queue; +// and one output, which is a single-element tensor containing the current +// size of that Queue. +class QueueSizeOp : public QueueOpKernel { + public: + explicit QueueSizeOp(OpKernelConstruction* context) + : QueueOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override { + Tensor* Tqueue_size = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size)); + Tqueue_size->flat<int32>().setConstant(queue->size()); + callback(); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp); +}; + +REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp); + +} // namespace tensorflow |