// See docs in ../ops/data_flow_ops.cc. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/queue_interface.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/port.h" #include "tensorflow/core/public/tensor.h" #include "tensorflow/core/public/tensor_shape.h" namespace tensorflow { class QueueOpKernel : public AsyncOpKernel { public: explicit QueueOpKernel(OpKernelConstruction* context) : AsyncOpKernel(context) {} void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { QueueInterface* queue; OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue), callback); ComputeAsync(ctx, queue, [callback, queue]() { queue->Unref(); callback(); }); } protected: virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) = 0; }; class QueueAccessOpKernel : public QueueOpKernel { public: explicit QueueAccessOpKernel(OpKernelConstruction* context) : QueueOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_)); // TODO(keveman): Enable timeout. OP_REQUIRES(context, timeout_ == -1, errors::InvalidArgument("Timeout not supported yet.")); } protected: int64 timeout_; }; // Defines an EnqueueOp, the execution of which enqueues a tuple of // tensors in the given Queue. // // The op has 1 + k inputs, where k is the number of components in the // tuples stored in the given Queue: // - Input 0: queue handle. // - Input 1: 0th element of the tuple. // - ... // - Input (1+k): kth element of the tuple. class EnqueueOp : public QueueAccessOpKernel { public: explicit EnqueueOp(OpKernelConstruction* context) : QueueAccessOpKernel(context) {} protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { DataTypeVector expected_inputs = {DT_STRING_REF}; for (DataType dt : queue->component_dtypes()) { expected_inputs.push_back(dt); } OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback); QueueInterface::Tuple tuple; OpInputList components; OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), callback); for (const Tensor& Tcomponent : components) { tuple.push_back(Tcomponent); } OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback); queue->TryEnqueue(tuple, ctx, callback); } private: TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp); }; REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp); // Defines an EnqueueManyOp, the execution of which slices each // component of a tuple of tensors along the 0th dimension, and // enqueues tuples of slices in the given Queue. // // The op has 1 + k inputs, where k is the number of components in the // tuples stored in the given Queue: // - Input 0: queue handle. // - Input 1: 0th element of the tuple. // - ... // - Input (1+k): kth element of the tuple. // // N.B. All tuple components must have the same size in the 0th // dimension. class EnqueueManyOp : public QueueAccessOpKernel { public: explicit EnqueueManyOp(OpKernelConstruction* context) : QueueAccessOpKernel(context) {} protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { DataTypeVector expected_inputs = {DT_STRING_REF}; for (DataType dt : queue->component_dtypes()) { expected_inputs.push_back(dt); } OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); QueueInterface::Tuple tuple; OpInputList components; OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components), callback); for (const Tensor& Tcomponent : components) { tuple.push_back(Tcomponent); } OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback); queue->TryEnqueueMany(tuple, ctx, callback); } ~EnqueueManyOp() override {} private: TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp); }; REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU), EnqueueManyOp); // Defines a DequeueOp, the execution of which dequeues a tuple of // tensors from the given Queue. // // The op has one input, which is the handle of the appropriate // Queue. The op has k outputs, where k is the number of components in // the tuples stored in the given Queue, and output i is the ith // component of the dequeued tuple. class DequeueOp : public QueueAccessOpKernel { public: explicit DequeueOp(OpKernelConstruction* context) : QueueAccessOpKernel(context) {} protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { OP_REQUIRES_OK_ASYNC( ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()), callback); queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { if (!ctx->status().ok()) { callback(); return; } OpOutputList output_components; OP_REQUIRES_OK_ASYNC( ctx, ctx->output_list("components", &output_components), callback); for (int i = 0; i < ctx->num_outputs(); ++i) { output_components.set(i, tuple[i]); } callback(); }); } ~DequeueOp() override {} private: TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp); }; REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp); // Defines a DequeueManyOp, the execution of which concatenates the // requested number of elements from the given Queue along the 0th // dimension, and emits the result as a single tuple of tensors. // // The op has two inputs: // - Input 0: the handle to a queue. // - Input 1: the number of elements to dequeue. // // The op has k outputs, where k is the number of components in the // tuples stored in the given Queue, and output i is the ith component // of the dequeued tuple. class DequeueManyOp : public QueueAccessOpKernel { public: explicit DequeueManyOp(OpKernelConstruction* context) : QueueAccessOpKernel(context) {} protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { const Tensor& Tnum_elements = ctx->input(1); int32 num_elements = Tnum_elements.flat()(0); OP_REQUIRES_ASYNC( ctx, num_elements >= 0, errors::InvalidArgument("DequeueManyOp must request a positive number " "of elements"), callback); OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32}, queue->component_dtypes()), callback); queue->TryDequeueMany( num_elements, ctx, [ctx, callback](const QueueInterface::Tuple& tuple) { if (!ctx->status().ok()) { callback(); return; } OpOutputList output_components; OP_REQUIRES_OK_ASYNC( ctx, ctx->output_list("components", &output_components), callback); for (int i = 0; i < ctx->num_outputs(); ++i) { output_components.set(i, tuple[i]); } callback(); }); } ~DequeueManyOp() override {} private: TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp); }; REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU), DequeueManyOp); // Defines a QueueCloseOp, which closes the given Queue. Closing a // Queue signals that no more elements will be enqueued in it. // // The op has one input, which is the handle of the appropriate Queue. class QueueCloseOp : public QueueOpKernel { public: explicit QueueCloseOp(OpKernelConstruction* context) : QueueOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues", &cancel_pending_enqueues_)); } protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { queue->Close(ctx, cancel_pending_enqueues_, callback); } private: bool cancel_pending_enqueues_; TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp); }; REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp); // Defines a QueueSizeOp, which computes the number of elements in the // given Queue, and emits it as an output tensor. // // The op has one input, which is the handle of the appropriate Queue; // and one output, which is a single-element tensor containing the current // size of that Queue. class QueueSizeOp : public QueueOpKernel { public: explicit QueueSizeOp(OpKernelConstruction* context) : QueueOpKernel(context) {} protected: void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, DoneCallback callback) override { Tensor* Tqueue_size = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size)); Tqueue_size->flat().setConstant(queue->size()); callback(); } private: TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp); }; REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp); } // namespace tensorflow