aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/random_shuffle_queue_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/random_shuffle_queue_op.cc')
-rw-r--r--tensorflow/core/kernels/random_shuffle_queue_op.cc57
1 files changed, 7 insertions, 50 deletions
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
index da391224af..e9b6ead381 100644
--- a/tensorflow/core/kernels/random_shuffle_queue_op.cc
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
@@ -404,17 +405,10 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
// backed by RandomShuffleQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
-class RandomShuffleQueueOp : public OpKernel {
+class RandomShuffleQueueOp : public QueueOp {
public:
explicit RandomShuffleQueueOp(OpKernelConstruction* context)
- : OpKernel(context), queue_handle_set_(false) {
- OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
- OP_REQUIRES_OK(context,
- context->allocate_persistent(DT_STRING, TensorShape({2}),
- &queue_handle_, nullptr));
- if (capacity_ < 0) {
- capacity_ = RandomShuffleQueue::kUnbounded;
- }
+ : QueueOp(context) {
OP_REQUIRES_OK(context,
context->GetAttr("min_after_dequeue", &min_after_dequeue_));
OP_REQUIRES(context, min_after_dequeue_ >= 0,
@@ -427,32 +421,12 @@ class RandomShuffleQueueOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
- OP_REQUIRES_OK(context,
- context->GetAttr("component_types", &component_types_));
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
}
- ~RandomShuffleQueueOp() override {
- // If the queue object was not shared, delete it.
- if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
- TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
- cinfo_.container(), cinfo_.name()));
- }
- }
-
- void Compute(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
- if (!queue_handle_set_) {
- OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
- }
- ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
- }
-
- private:
- Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
- QueueInterface* queue;
- auto creator = [this](QueueInterface** ret) {
+ protected:
+ CreatorCallback GetCreator() const override {
+ return [this](QueueInterface** ret) {
auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_,
seed2_, component_types_,
component_shapes_, cinfo_.name());
@@ -464,30 +438,13 @@ class RandomShuffleQueueOp : public OpKernel {
}
return s;
};
- TF_RETURN_IF_ERROR(
- cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
- cinfo_.container(), cinfo_.name(), &queue, creator));
- core::ScopedUnref unref_me(queue);
- // Verify that the shared queue is compatible with the requested arguments.
- TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
- auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
- queue_handle_set_ = true;
- return Status::OK();
}
- int32 capacity_;
+ private:
int32 min_after_dequeue_;
int64 seed_;
int64 seed2_;
- DataTypeVector component_types_;
std::vector<TensorShape> component_shapes_;
- ContainerInfo cinfo_;
-
- mutex mu_;
- PersistentTensor queue_handle_ GUARDED_BY(mu_);
- bool queue_handle_set_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp);
};