Diffstat (limited to 'tensorflow/core/kernels/random_shuffle_queue_op.cc')
-rw-r--r--  tensorflow/core/kernels/random_shuffle_queue_op.cc  740
1 file changed, 740 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
new file mode 100644
index 0000000000..561ec76e53
--- /dev/null
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -0,0 +1,740 @@
+// See docs in ../ops/data_flow_ops.cc.
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class RandomShuffleQueue : public QueueBase {
+ public:
+ RandomShuffleQueue(int32 capacity, int32 min_after_dequeue, int64 seed,
+ int64 seed2, const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes,
+ const string& name);
+ Status Initialize(); // Must be called before any other method.
+
+ // Implementations of QueueInterface methods --------------------------------
+
+ Status ValidateTuple(const Tuple& tuple) override;
+ Status ValidateManyTuple(const Tuple& tuple) override;
+ void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) override;
+ void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) override;
+ void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
+ void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) override;
+ void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) override;
+ Status MatchesNodeDef(const NodeDef& node_def) override;
+
+ int32 size() override {
+ mutex_lock lock(mu_);
+ return queues_[0].size();
+ }
+
+ private:
+ enum Action { kEnqueue, kDequeue };
+
+ ~RandomShuffleQueue() override {}
+
+ TensorShape ManyOutShape(int i, int batch_size) {
+ TensorShape shape({batch_size});
+ shape.AppendShape(component_shapes_[i]);
+ return shape;
+ }
+
+ // Helper for dequeuing a single random element from queues_.
+ void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ void Cancel(Action action, CancellationToken token);
+
+ // Helper for cancelling all pending Enqueue(Many) operations when
+ // Close is called with cancel_pending_enqueues.
+ void CloseAndCancel();
+
+  // Tries to enqueue/dequeue (or close) based on whatever is at the
+  // front of enqueue_attempts_/dequeue_attempts_. Appends to
+  // *clean_up the callback for any finished attempt (so it may be
+  // called once mu_ is released). Returns true if any progress was
+  // made.
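+
+  // CleanUp bundles a finished attempt's done callback with the
+  // cancellation token to deregister from its CancellationManager
+  // before the done callback is run.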
+ struct CleanUp {
+ CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
+ : finished(f), to_deregister(ct), cm(cm) {}
+ DoneCallback finished;
+ CancellationToken to_deregister;
+ CancellationManager* cm;
+ };
+ bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Tries to make progress on the enqueues or dequeues at the front
+ // of the *_attempts_ queues.
+ void FlushUnlocked();
+
+ const int32 capacity_;
+ const int32 min_after_dequeue_;
+ const int64 original_seed_;
+ const int64 original_seed2_;
+
+ mutex mu_;
+ typedef std::vector<PersistentTensor> SubQueue;
+ std::vector<SubQueue> queues_ GUARDED_BY(mu_);
+ bool closed_ GUARDED_BY(mu_);
+ random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
+ random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);
+
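+  // Result of running an attempt's run_callback while holding mu_:
+  //   kNoProgress - nothing could be done; the attempt stays at the front.
+  //   kProgress   - some elements were processed, but the attempt is not
+  //                 yet satisfied and stays at the front.
+  //   kComplete   - the attempt is finished (successfully or with an error);
+  //                 it is removed and its done_callback runs once mu_ is
+  //                 released.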
+ enum RunResult { kNoProgress, kProgress, kComplete };
+ struct Attempt;
+ typedef std::function<RunResult(Attempt*)> RunCallback;
+ struct Attempt {
+ int32 elements_requested;
+ DoneCallback done_callback; // must be run outside mu_
+ OpKernelContext* context;
+ CancellationToken cancellation_token;
+ RunCallback run_callback; // must be run while holding mu_
+ bool is_cancelled;
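+    // Output accumulated so far by a partially satisfied DequeueMany.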
+ Tuple tuple;
+
+ Attempt(int32 elements_requested, DoneCallback done_callback,
+ OpKernelContext* context, CancellationToken cancellation_token,
+ RunCallback run_callback)
+ : elements_requested(elements_requested),
+ done_callback(done_callback),
+ context(context),
+ cancellation_token(cancellation_token),
+ run_callback(run_callback),
+ is_cancelled(false) {}
+ };
+ std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
+ std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
+};
+
+RandomShuffleQueue::RandomShuffleQueue(
+    int32 capacity, int32 min_after_dequeue, int64 seed, int64 seed2,
+ const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes, const string& name)
+ : QueueBase(component_dtypes, component_shapes, name),
+ capacity_(capacity),
+ min_after_dequeue_(min_after_dequeue),
+ original_seed_(seed),
+ original_seed2_(seed2),
+ closed_(false),
+ generator_(&parent_generator_) {
+ if (seed == 0 && seed2 == 0) {
+ // If both seeds are unspecified, use completely random seeds.
+ seed = random::New64();
+ seed2 = random::New64();
+ }
+ parent_generator_ = random::PhiloxRandom(seed, seed2);
+}
+
+Status RandomShuffleQueue::Initialize() {
+ if (component_dtypes_.empty()) {
+ return errors::InvalidArgument("Empty component types for queue ", name_);
+ }
+ if (!component_shapes_.empty() &&
+ component_dtypes_.size() != component_shapes_.size()) {
+ return errors::InvalidArgument("Different number of component types (",
+ component_dtypes_.size(), ") vs. shapes (",
+ component_shapes_.size(), ").");
+ }
+
+ mutex_lock lock(mu_);
+ queues_.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ queues_.push_back(SubQueue());
+ queues_.back().reserve(min_after_dequeue_);
+ }
+ return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status RandomShuffleQueue::ValidateTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ component_shapes_[i].ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status RandomShuffleQueue::ValidateManyTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ const int64 batch_size = tuple[0].dim_size(0);
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ // Expected shape is [batch_size] + component_shapes_[i]
+ const TensorShape expected_shape = ManyOutShape(i, batch_size);
+ if (!tuple[i].shape().IsSameSize(expected_shape)) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ expected_shape.ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ } else {
+ for (size_t i = 1; i < tuple.size(); ++i) {
+ if (tuple[i].dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All input tensors must have the same size in the 0th ",
+ "dimension. Component ", i, " has ", tuple[i].dim_size(0),
+ ", and should have ", batch_size);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
+ DCHECK_GT(queues_[0].size(), 0);
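+  // Pick a slot uniformly at random, then remove it from every component
+  // SubQueue by overwriting it with the last element and popping the back.
+  // This is O(1) and safe because element order within a SubQueue does not
+  // matter.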
+ int64 index = generator_() % queues_[0].size();
+ (*tuple).reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ (*tuple).push_back(*queues_[i][index].AccessTensor(ctx));
+ queues_[i][index] = queues_[i].back();
+ queues_[i].pop_back();
+ }
+}
+
+void RandomShuffleQueue::Cancel(Action action, CancellationToken token) {
+ DoneCallback callback = nullptr;
+ {
+ mutex_lock lock(mu_);
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ for (Attempt& attempt : *attempts) {
+ if (attempt.cancellation_token == token) {
+ attempt.is_cancelled = true;
+ if (action == kEnqueue) {
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ } else {
+ attempt.context->SetStatus(
+ errors::Cancelled("Dequeue operation was cancelled"));
+ }
+ std::swap(callback, attempt.done_callback);
+ break;
+ }
+ }
+ }
+ if (callback) {
+ callback();
+ FlushUnlocked();
+ }
+}
+
+void RandomShuffleQueue::CloseAndCancel() {
+ std::vector<DoneCallback> callbacks;
+ {
+ mutex_lock lock(mu_);
+ closed_ = true;
+ for (Attempt& attempt : enqueue_attempts_) {
+ attempt.is_cancelled = true;
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ callbacks.emplace_back(std::move(attempt.done_callback));
+ }
+ }
+ for (const DoneCallback& callback : callbacks) {
+ callback();
+ }
+ FlushUnlocked();
+}
+
+bool RandomShuffleQueue::TryAttemptLocked(
+ Action action, std::vector<CleanUp>* clean_up) {
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ bool progress = false;
+ bool done = false;
+ while (!done && !attempts->empty()) {
+ if (attempts->front().is_cancelled) {
+ if (action == kEnqueue) {
+ LOG(INFO) << "Skipping cancelled enqueue attempt";
+ } else {
+ LOG(INFO) << "Skipping cancelled dequeue attempt";
+ }
+ attempts->pop_front();
+ } else {
+ Attempt* cur_attempt = &attempts->front();
+ switch (cur_attempt->run_callback(cur_attempt)) {
+ case kNoProgress:
+ done = true;
+ break;
+ case kProgress:
+ done = true;
+ progress = true;
+ break;
+ case kComplete:
+ progress = true;
+ clean_up->emplace_back(std::move(cur_attempt->done_callback),
+ cur_attempt->cancellation_token,
+ cur_attempt->context->cancellation_manager());
+ attempts->pop_front();
+ break;
+ }
+ }
+ }
+ return progress;
+}
+
+void RandomShuffleQueue::FlushUnlocked() {
+ std::vector<CleanUp> clean_up;
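+  // Take a reference for the duration of the flush so the queue is not
+  // destroyed while attempts (and their callbacks) are being processed.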
+ Ref();
+ {
+ mutex_lock lock(mu_);
+ bool changed;
+ do {
+ changed = TryAttemptLocked(kEnqueue, &clean_up);
+ changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
+ } while (changed);
+ }
+ Unref();
+ for (const auto& to_clean : clean_up) {
+ if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
+ // NOTE(mrry): We can safely ignore the return value of
+ // DeregisterCallback because the mutex mu_ ensures that the
+ // cleanup action only executes once.
+ to_clean.cm->DeregisterCallback(to_clean.to_deregister);
+ }
+ to_clean.finished();
+ }
+}
+
+void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) {
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kEnqueue, token); });
+ if (!already_cancelled) {
+ enqueue_attempts_.emplace_back(
+ 1, callback, ctx, token,
+ [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(errors::Aborted(
+ "RandomShuffleQueue '", name_, "' is closed."));
+ return kComplete;
+ }
+ if (queues_[0].size() < static_cast<size_t>(capacity_)) {
+ for (int i = 0; i < num_components(); ++i) {
+ queues_[i].push_back(PersistentTensor(tuple[i]));
+ }
+ return kComplete;
+ } else {
+ return kNoProgress;
+ }
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
+ callback();
+ }
+}
+
+void RandomShuffleQueue::TryEnqueueMany(const Tuple& tuple,
+ OpKernelContext* ctx,
+ DoneCallback callback) {
+ const int64 batch_size = tuple[0].dim_size(0);
+ if (batch_size == 0) {
+ callback();
+ return;
+ }
+
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kEnqueue, token); });
+ if (!already_cancelled) {
+ enqueue_attempts_.emplace_back(
+ batch_size, callback, ctx, token,
+ [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(errors::Aborted(
+ "RandomShuffleQueue '", name_, "' is closed."));
+ return kComplete;
+ }
+ RunResult result = kNoProgress;
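+            // Move one input row per loop iteration while there is spare
+            // capacity; rows are consumed front to back as
+            // elements_requested counts down to zero.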
+ while (queues_[0].size() < static_cast<size_t>(capacity_)) {
+ result = kProgress;
+ const int index =
+ tuple[0].dim_size(0) - attempt->elements_requested;
+ for (int i = 0; i < num_components(); ++i) {
+ TensorShape element_shape(tuple[i].shape());
+ element_shape.RemoveDim(0);
+ PersistentTensor element;
+ Tensor* element_access = nullptr;
+ attempt->context->allocate_persistent(
+ tuple[i].dtype(), element_shape, &element, &element_access);
+ attempt->context->SetStatus(
+ CopySliceToElement(tuple[i], element_access, index));
+ if (!attempt->context->status().ok()) return kComplete;
+ queues_[i].push_back(element);
+ }
+ --attempt->elements_requested;
+ if (attempt->elements_requested == 0) {
+ return kComplete;
+ }
+ }
+ return result;
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
+ callback();
+ }
+}
+
+void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
+ CallbackWithTuple callback) {
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kDequeue, token); });
+ if (!already_cancelled) {
+ // TODO(josh11b): This makes two copies of callback, avoid this if possible.
+ dequeue_attempts_.emplace_back(
+ 1, [callback]() { callback(Tuple()); }, ctx, token,
+ [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int32 s = queues_[0].size();
+ if (closed_ && s == 0) {
+ attempt->context->SetStatus(errors::OutOfRange(
+ "RandomShuffleQueue '", name_, "' is closed and has ",
+ "insufficient elements (requested ", 1, ", current size ", s,
+ ")"));
+ return kComplete;
+ }
+ if (!closed_) s -= min_after_dequeue_;
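+            // While the queue is open, min_after_dequeue_ elements are held
+            // back so that dequeued elements remain well shuffled; once the
+            // queue is closed, everything that is left may be dequeued.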
+ if (s > 0) {
+ Tuple tuple;
+ DequeueLocked(attempt->context, &tuple);
+ attempt->done_callback = [callback, tuple]() { callback(tuple); };
+ return kComplete;
+ } else {
+ return kNoProgress;
+ }
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
+ callback(Tuple());
+ }
+}
+
+void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) {
+ if (!specified_shapes()) {
+ ctx->SetStatus(
+ errors::InvalidArgument("RandomShuffleQueue's DequeueMany requires the "
+ "components to have specified shapes."));
+ callback(Tuple());
+ return;
+ }
+ if (num_elements == 0) {
+ Tuple tuple;
+ tuple.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ // TODO(josh11b,misard): Switch to allocate_output(). Problem is
+ // this breaks the abstraction boundary since we don't *really*
+ // know if and how the Tensors in the tuple we pass to callback
+ // correspond to the outputs of *ctx. For example, the
+ // ReaderRead Op uses TryDequeue() to get a filename out of a
+ // queue that is used internally by the reader and is not
+ // associated with any output of the ReaderRead.
+ // mrry@ adds:
+ // Maybe we need to pass a std::function<Tensor*(...)> (or
+ // better signature) that calls the appropriate allocator
+ // function in addition to ctx? (Or support a shim Allocator
+ // that has an internal OpKernelContext*, and dispatches to the
+ // appropriate method?)
+ // misard@ adds:
+ // I don't see that a std::function would help. The problem is
+ // that at this point (allocation time) the system doesn't know
+ // what is going to happen to the element read out of the
+ // queue. As long as we keep the generality that TensorFlow Ops
+ // do their own dynamic allocation in arbitrary C++ code, we
+ // need to preserve robustness to allocating output Tensors with
+ // the 'wrong' attributes, and fixing up with a copy. The only
+ // improvement I can see here in the future would be to support
+ // an optimized case where the queue 'knows' what attributes to
+ // use, and plumbs them through here.
+ Tensor element;
+ ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0), &element);
+ tuple.emplace_back(element);
+ }
+ callback(tuple);
+ return;
+ }
+
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kDequeue, token); });
+ if (!already_cancelled) {
+ // TODO(josh11b): This makes two copies of callback, avoid this if possible.
+ dequeue_attempts_.emplace_back(
+ num_elements, [callback]() { callback(Tuple()); }, ctx, token,
+ [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int32 s = queues_[0].size();
+ if (closed_ && s < attempt->elements_requested) {
+ attempt->context->SetStatus(errors::OutOfRange(
+                "RandomShuffleQueue '", name_, "' is closed and has ",
+ "insufficient elements (requested ",
+ attempt->elements_requested, ", current size ", s, ")"));
+ return kComplete;
+ }
+
+ RunResult result = kNoProgress;
+ if (!closed_) s -= min_after_dequeue_;
+ for (; s > 0; --s) {
+ if (attempt->tuple.empty()) {
+ // Only allocate tuple when we have something to dequeue
+                // so we don't use excessive memory when there are many
+ // blocked dequeue attempts waiting.
+ attempt->tuple.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ const TensorShape shape =
+ ManyOutShape(i, attempt->elements_requested);
+ Tensor element;
+ attempt->context->allocate_temp(component_dtypes_[i], shape,
+ &element);
+ attempt->tuple.emplace_back(element);
+ }
+ }
+ result = kProgress;
+ Tuple tuple;
+ DequeueLocked(attempt->context, &tuple);
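+                // Fill the output batch front to back: dim_size(0) is the
+                // number of rows originally requested and
+                // elements_requested counts those still outstanding.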
+ const int index =
+ attempt->tuple[0].dim_size(0) - attempt->elements_requested;
+ for (int i = 0; i < num_components(); ++i) {
+ attempt->context->SetStatus(
+ CopyElementToSlice(tuple[i], &attempt->tuple[i], index));
+ if (!attempt->context->status().ok()) return kComplete;
+ }
+ tuple.clear();
+ --attempt->elements_requested;
+ if (attempt->elements_requested == 0) {
+ tuple = attempt->tuple;
+ attempt->done_callback = [callback, tuple]() {
+ callback(tuple);
+ };
+ return kComplete;
+ }
+ }
+ return result;
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
+ callback(Tuple());
+ }
+}
+
+void RandomShuffleQueue::Close(OpKernelContext* ctx,
+ bool cancel_pending_enqueues,
+ DoneCallback callback) {
+ if (cancel_pending_enqueues) {
+ CloseAndCancel();
+ callback();
+ } else {
+ {
+ mutex_lock lock(mu_);
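+      // Append a zero-element attempt to the back of enqueue_attempts_ so
+      // that enqueues already pending are processed before the queue is
+      // marked closed.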
+ enqueue_attempts_.emplace_back(
+ 0, callback, ctx, CancellationManager::kInvalidToken,
+ [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(errors::Aborted(
+ "RandomShuffleQueue '", name_, "' is already closed."));
+ } else {
+ closed_ = true;
+ }
+ return kComplete;
+ });
+ }
+ FlushUnlocked();
+ }
+}
+
+Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
+ TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "RandomShuffleQueue"));
+ TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
+
+ int32 min_after_dequeue = -1;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node_def, "min_after_dequeue", &min_after_dequeue));
+ if (min_after_dequeue != min_after_dequeue_) {
+ return errors::InvalidArgument(
+ "Shared queue '", name_, "' has min_after_dequeue ",
+ min_after_dequeue_, " but requested min_after_dequeue was ",
+ min_after_dequeue, ".");
+ }
+
+ int64 seed = -1;
+ int64 seed2 = -1;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed", &seed));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed2", &seed2));
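+  // Requested seeds of (0, 0) mean the caller left the seeds unspecified,
+  // so any seeds on the shared queue are acceptable.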
+ if ((seed != 0 || seed2 != 0) &&
+ (seed != original_seed_ || seed2 != original_seed2_)) {
+ return errors::InvalidArgument(
+ "Shared queue '", name_, "' has random seeds (", original_seed_, ", ",
+ original_seed2_, ") but requested seeds are (", seed, ", ", seed2,
+ ").");
+ }
+
+ TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
+ TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def));
+
+ return Status::OK();
+}
+
+typedef std::shared_ptr<QueueInterface> QueueInterfacePtr;
+
+// Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
+// backed by RandomShuffleQueue) that persists across different graph
+// executions and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+class RandomShuffleQueueOp : public OpKernel {
+ public:
+ explicit RandomShuffleQueueOp(OpKernelConstruction* context)
+ : OpKernel(context), queue_handle_set_(false) {
+ OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_STRING, TensorShape({2}),
+ &queue_handle_, nullptr));
+ if (capacity_ < 0) {
+ capacity_ = RandomShuffleQueue::kUnbounded;
+ }
+ OP_REQUIRES_OK(context,
+ context->GetAttr("min_after_dequeue", &min_after_dequeue_));
+ OP_REQUIRES(context, min_after_dequeue_ >= 0,
+ errors::InvalidArgument("min_after_dequeue ",
+ min_after_dequeue_, " must be >= 0"));
+ OP_REQUIRES(
+ context, min_after_dequeue_ < capacity_,
+ errors::InvalidArgument("min_after_dequeue ", min_after_dequeue_,
+ " must be < capacity ", capacity_));
+ OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
+ OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
+
+ OP_REQUIRES_OK(context,
+ context->GetAttr("component_types", &component_types_));
+ OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
+ }
+
+ ~RandomShuffleQueueOp() override {
+ // If the queue object was not shared, delete it.
+ if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
+ TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
+ cinfo_.container(), cinfo_.name()));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ if (!queue_handle_set_) {
+ OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
+ }
+ ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
+ }
+
+ private:
+ Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
+ QueueInterface* queue;
+ auto creator = [this](QueueInterface** ret) {
+ auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_,
+ seed2_, component_types_,
+ component_shapes_, cinfo_.name());
+ Status s = q->Initialize();
+ if (s.ok()) {
+ *ret = q;
+ } else {
+ q->Unref();
+ }
+ return s;
+ };
+ TF_RETURN_IF_ERROR(
+ cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
+ cinfo_.container(), cinfo_.name(), &queue, creator));
+ core::ScopedUnref unref_me(queue);
+ // Verify that the shared queue is compatible with the requested arguments.
+ TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
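+    // The handle is a 2-element string tensor naming the resource
+    // container and the queue; together these identify the queue in the
+    // ResourceMgr.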
+ auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ queue_handle_set_ = true;
+ return Status::OK();
+ }
+
+ int32 capacity_;
+ int32 min_after_dequeue_;
+ int64 seed_;
+ int64 seed2_;
+ DataTypeVector component_types_;
+ std::vector<TensorShape> component_shapes_;
+ ContainerInfo cinfo_;
+
+ mutex mu_;
+ PersistentTensor queue_handle_ GUARDED_BY(mu_);
+ bool queue_handle_set_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU),
+ RandomShuffleQueueOp);
+
+} // namespace tensorflow