aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/fifo_queue.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/fifo_queue.h')
-rw-r--r--tensorflow/core/kernels/fifo_queue.h127
1 files changed, 127 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
new file mode 100644
index 0000000000..e9fe5f34a4
--- /dev/null
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -0,0 +1,127 @@
+#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class FIFOQueue : public QueueBase {
+ public:
+ FIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes,
+ const string& name);
+ Status Initialize(); // Must be called before any other method.
+
+ // Implementations of QueueInterface methods --------------------------------
+
+ Status ValidateTuple(const Tuple& tuple) override;
+ Status ValidateManyTuple(const Tuple& tuple) override;
+ void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) override;
+ void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) override;
+ void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
+ void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) override;
+ void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) override;
+ Status MatchesNodeDef(const NodeDef& node_def) override;
+
+ int32 size() override {
+ mutex_lock lock(mu_);
+ return queues_[0].size();
+ }
+
+ int32 capacity() const { return capacity_; }
+
+ private:
+ enum Action { kEnqueue, kDequeue };
+
+ ~FIFOQueue() override {}
+
+ TensorShape ManyOutShape(int i, int64 batch_size) {
+ TensorShape shape({batch_size});
+ shape.AppendShape(component_shapes_[i]);
+ return shape;
+ }
+
+ // Helper for dequeuing a single element from queues_.
+ void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ void Cancel(Action action, CancellationToken token);
+
+ // Helper for cancelling all pending Enqueue(Many) operations when
+ // Close is called with cancel_pending_enqueues.
+ void CloseAndCancel();
+
+ // Tries to enqueue/dequeue (or close) based on whatever is at the
+ // front of enqueue_attempts_/dequeue_attempts_. Appends to
+ // *finished the callback for any finished attempt (so it may be
+ // called once mu_ is released). Returns true if any progress was
+ // made.
+ struct CleanUp {
+ CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
+ : finished(f), to_deregister(ct), cm(cm) {}
+ DoneCallback finished;
+ CancellationToken to_deregister;
+ CancellationManager* cm;
+ };
+ bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Tries to make progress on the enqueues or dequeues at the front
+ // of the *_attempts_ queues.
+ void FlushUnlocked();
+
+ const int32 capacity_;
+
+ mutex mu_;
+ typedef std::deque<PersistentTensor> SubQueue;
+ std::vector<SubQueue> queues_ GUARDED_BY(mu_);
+ bool closed_ GUARDED_BY(mu_);
+
+ enum RunResult { kNoProgress, kProgress, kComplete };
+ struct Attempt;
+ typedef std::function<RunResult(Attempt*)> RunCallback;
+ struct Attempt {
+ int32 elements_requested;
+ DoneCallback done_callback; // must be run outside mu_
+ OpKernelContext* context;
+ CancellationToken cancellation_token;
+ RunCallback run_callback; // must be run while holding mu_
+ bool is_cancelled;
+ Tuple tuple;
+
+ Attempt(int32 elements_requested, DoneCallback done_callback,
+ OpKernelContext* context, CancellationToken cancellation_token,
+ RunCallback run_callback)
+ : elements_requested(elements_requested),
+ done_callback(done_callback),
+ context(context),
+ cancellation_token(cancellation_token),
+ run_callback(run_callback),
+ is_cancelled(false) {}
+ };
+ std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
+ std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
+
+ static Status GetElementComponentFromBatch(const Tuple& tuple, int index,
+ int component,
+ OpKernelContext* ctx,
+ PersistentTensor* out_element);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_