diff options
Diffstat (limited to 'tensorflow/core/kernels/random_shuffle_queue_op.cc')
-rw-r--r-- | tensorflow/core/kernels/random_shuffle_queue_op.cc | 57 |
1 files changed, 7 insertions, 50 deletions
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index da391224af..e9b6ead381 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/typed_queue.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/random/philox_random.h" @@ -404,17 +405,10 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) { // backed by RandomShuffleQueue) that persists across different graph // executions, and sessions. Running this op produces a single-element // tensor of handles to Queues in the corresponding device. -class RandomShuffleQueueOp : public OpKernel { +class RandomShuffleQueueOp : public QueueOp { public: explicit RandomShuffleQueueOp(OpKernelConstruction* context) - : OpKernel(context), queue_handle_set_(false) { - OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_)); - OP_REQUIRES_OK(context, - context->allocate_persistent(DT_STRING, TensorShape({2}), - &queue_handle_, nullptr)); - if (capacity_ < 0) { - capacity_ = RandomShuffleQueue::kUnbounded; - } + : QueueOp(context) { OP_REQUIRES_OK(context, context->GetAttr("min_after_dequeue", &min_after_dequeue_)); OP_REQUIRES(context, min_after_dequeue_ >= 0, @@ -427,32 +421,12 @@ class RandomShuffleQueueOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_)); OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_)); - OP_REQUIRES_OK(context, - context->GetAttr("component_types", &component_types_)); OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); } - ~RandomShuffleQueueOp() override { - // If the queue object was not shared, delete it. - if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) { - TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>( - cinfo_.container(), cinfo_.name())); - } - } - - void Compute(OpKernelContext* ctx) override { - mutex_lock l(mu_); - if (!queue_handle_set_) { - OP_REQUIRES_OK(ctx, SetQueueHandle(ctx)); - } - ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx)); - } - - private: - Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); - QueueInterface* queue; - auto creator = [this](QueueInterface** ret) { + protected: + CreatorCallback GetCreator() const override { + return [this](QueueInterface** ret) { auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_, seed2_, component_types_, component_shapes_, cinfo_.name()); @@ -464,30 +438,13 @@ class RandomShuffleQueueOp : public OpKernel { } return s; }; - TF_RETURN_IF_ERROR( - cinfo_.resource_manager()->LookupOrCreate<QueueInterface>( - cinfo_.container(), cinfo_.name(), &queue, creator)); - core::ScopedUnref unref_me(queue); - // Verify that the shared queue is compatible with the requested arguments. - TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def())); - auto h = queue_handle_.AccessTensor(ctx)->flat<string>(); - h(0) = cinfo_.container(); - h(1) = cinfo_.name(); - queue_handle_set_ = true; - return Status::OK(); } - int32 capacity_; + private: int32 min_after_dequeue_; int64 seed_; int64 seed2_; - DataTypeVector component_types_; std::vector<TensorShape> component_shapes_; - ContainerInfo cinfo_; - - mutex mu_; - PersistentTensor queue_handle_ GUARDED_BY(mu_); - bool queue_handle_set_ GUARDED_BY(mu_); TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp); }; |