-rw-r--r--  tensorflow/compiler/jit/BUILD                   |   2
-rw-r--r--  tensorflow/compiler/jit/xla_device_ops.h        |  29
-rw-r--r--  tensorflow/compiler/tests/BUILD                 |  14
-rw-r--r--  tensorflow/compiler/tests/fifo_queue_test.py    | 201
-rw-r--r--  tensorflow/contrib/makefile/tf_op_files.txt     |   1
-rw-r--r--  tensorflow/core/framework/resource_op_kernel.h  |  25
-rw-r--r--  tensorflow/core/kernels/BUILD                   |   5
-rw-r--r--  tensorflow/core/kernels/fifo_queue.cc           |  15
-rw-r--r--  tensorflow/core/kernels/fifo_queue.h            |  23
-rw-r--r--  tensorflow/core/kernels/fifo_queue_op.cc        |  39
-rw-r--r--  tensorflow/core/kernels/queue_op.cc             | 367
-rw-r--r--  tensorflow/core/kernels/queue_op.h              | 233
-rw-r--r--  tensorflow/core/kernels/queue_ops.cc            | 395
13 files changed, 883 insertions(+), 466 deletions(-)
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index d976f8296c..c2245b8eae 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -176,9 +176,11 @@ cc_library(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
+ "//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 11e45d2823..a605335a94 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -23,9 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
+#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
@@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
- LoopCondOp);
+ LoopCondOp); \
+ \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
+ REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \
+ .Device(DEVICE) \
+ .HostMemory("size") \
+ .HostMemory("handle"), \
+ QueueSizeOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
+ QueueIsClosedOp); \
+ \
+ REGISTER_KERNEL_BUILDER( \
+ Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);
+
+// TODO(phawkins): currently we do not register the QueueEnqueueMany,
+// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
+// and write the tensors they access in order to concatenate them into a batch.
+// We would either need to call out to an XLA computation to perform the
+// concatenation, or refactor those kernels so that the splitting or merging
+// is done in a separate operator that can be compiled.
} // namespace tensorflow
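
With these registrations in place, the per-element queue ops can be placed on an XLA device, while the batched variants remain unregistered for the reasons the TODO above explains. A minimal Python sketch of the resulting surface (not part of this diff; assumes the TF 1.x API and that an XLA_CPU device is available — the device name is an assumption):

import tensorflow as tf

with tf.device("/device:XLA_CPU:0"):  # assumed XLA device name
  q = tf.FIFOQueue(10, tf.float32)    # FIFOQueueV2: registered above
  enqueue_op = q.enqueue((10.0,))     # QueueEnqueueV2: registered above
  dequeued = q.dequeue()              # QueueDequeueV2: registered above
  # q.enqueue_many(([1.0, 2.0],))     # QueueEnqueueManyV2: not registered (see TODO)

with tf.Session() as sess:
  sess.run(enqueue_op)
  print(sess.run(dequeued))  # 10.0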
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c1f65416b4..366822f0b7 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -372,6 +372,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "fifo_queue_test",
+ size = "medium",
+ srcs = ["fifo_queue_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:extra_py_tests_deps",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py
new file mode 100644
index 0000000000..0f64cc87cd
--- /dev/null
+++ b/tensorflow/compiler/tests/fifo_queue_test.py
@@ -0,0 +1,201 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.platform import test
+
+
+class FIFOQueueTest(xla_test.XLATestCase):
+
+ def testEnqueue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ enqueue_op = q.enqueue((10.0,))
+ enqueue_op.run()
+
+ def testEnqueueWithShape(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
+ enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
+ enqueue_correct_op.run()
+ with self.assertRaises(ValueError):
+ q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
+ self.assertEqual(1, q.size().eval())
+
+ def testMultipleDequeues(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue([1]))
+ self.evaluate(q.enqueue([2]))
+ self.evaluate(q.enqueue([3]))
+ a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+ self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+ def testQueuesDontShare(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue(1))
+ q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q2.enqueue(2))
+ self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+ self.assertAllEqual(self.evaluate(q.dequeue()), 1)
+
+ def testEnqueueDictWithoutNames(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ with self.assertRaisesRegexp(ValueError, "must have names"):
+ q.enqueue({"a": 12.0})
+
+ def testParallelEnqueue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Run one producer thread for each element in elems.
+ def enqueue(enqueue_op):
+ sess.run(enqueue_op)
+
+ threads = [
+ self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
+ ]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ # Dequeue every element using a single thread.
+ results = []
+ for _ in xrange(len(elems)):
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems, results)
+
+ def testParallelDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Enqueue every element using a single thread.
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ # Run one consumer thread for each element in elems.
+ results = []
+
+ def dequeue():
+ results.append(sess.run(dequeued_t))
+
+ threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, results)
+
+ def testDequeue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ vals = dequeued_t.eval()
+ self.assertEqual([elems[i]], vals)
+
+ def testEnqueueAndBlockingDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ def enqueue():
+ # The enqueue_ops should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ for enqueue_op in enqueue_ops:
+ sess.run(enqueue_op)
+
+ results = []
+
+ def dequeue():
+ for _ in xrange(len(elems)):
+ results.append(sess.run(dequeued_t))
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ for elem, result in zip(elems, results):
+ self.assertEqual([elem], result)
+
+ def testMultiEnqueueAndDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
+ elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
+ enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ x_val, y_val = sess.run(dequeued_t)
+ x, y = elems[i]
+ self.assertEqual([x], x_val)
+ self.assertEqual([y], y_val)
+
+ def testQueueSizeEmpty(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ self.assertEqual([0], q.size().eval())
+
+ def testQueueSizeAfterEnqueueAndDequeue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue()
+ size = q.size()
+ self.assertEqual([], size.get_shape())
+
+ enqueue_op.run()
+ self.assertEqual(1, size.eval())
+ dequeued_t.op.run()
+ self.assertEqual(0, size.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 89db9ee279..6e7423f85e 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc
tensorflow/core/kernels/reduction_ops_any.cc
tensorflow/core/kernels/reduction_ops_all.cc
tensorflow/core/kernels/roll_op.cc
+tensorflow/core/kernels/queue_op.cc
tensorflow/core/kernels/queue_ops.cc
tensorflow/core/kernels/queue_base.cc
tensorflow/core/kernels/pooling_ops_common.cc
diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h
index 813ec6eed5..0a8da8b3bf 100644
--- a/tensorflow/core/framework/resource_op_kernel.h
+++ b/tensorflow/core/framework/resource_op_kernel.h
@@ -43,9 +43,15 @@ template <typename T>
class ResourceOpKernel : public OpKernel {
public:
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context,
- context->allocate_persistent(DT_STRING, TensorShape({2}),
- &handle_, nullptr));
+ has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
+ if (!has_resource_type_) {
+ // The resource variant of the op may be placed on non-CPU devices, but
+ // this allocation is always on the host. Fortunately we don't need it in
+ // the resource case.
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_STRING, TensorShape({2}),
+ &handle_, nullptr));
+ }
}
// The resource is deleted from the resource manager only when it is private
@@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel {
return;
}
- auto h = handle_.AccessTensor(context)->template flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
+ if (!has_resource_type_) {
+ auto h = handle_.AccessTensor(context)->template flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ }
resource_ = resource;
}
- if (context->expected_output_dtype(0) == DT_RESOURCE) {
+ if (has_resource_type_) {
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<T>()));
@@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel {
virtual Status VerifyResource(T* resource) { return Status::OK(); }
PersistentTensor handle_ GUARDED_BY(mu_);
+
+ // Is the output of the operator of type DT_RESOURCE?
+ bool has_resource_type_;
};
} // namespace tensorflow
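
For context on this change: the two-element DT_STRING tensor is the legacy reference handle, which resource-variant ops never consume; they emit a DT_RESOURCE handle via MakeResourceHandleToOutput instead, so the host-side string allocation can be skipped. A one-line sketch of the distinction as seen from Python (TF 1.x API; that queue_ref carries a resource dtype is an assumption based on FIFOQueue wrapping FIFOQueueV2):

import tensorflow as tf

q = tf.FIFOQueue(10, tf.float32)  # wraps the resource-based FIFOQueueV2
print(q.queue_ref.dtype)          # tf.resource: no host string handle required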
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cbe30cdca1..861fb1ef69 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -368,6 +368,7 @@ cc_library(
cc_library(
name = "queue_op",
+ srcs = ["queue_op.cc"],
hdrs = ["queue_op.h"],
deps = [
":queue_base",
@@ -1885,9 +1886,10 @@ cc_library(
name = "fifo_queue",
srcs = ["fifo_queue.cc"],
hdrs = ["fifo_queue.h"],
- visibility = ["//visibility:private"],
+ visibility = [":friends"],
deps = [
":queue_base",
+ ":queue_op",
":typed_queue",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -5076,6 +5078,7 @@ filegroup(
"padding_fifo_queue.cc",
"padding_fifo_queue_op.cc",
"queue_base.cc",
+ "queue_op.cc",
"queue_ops.cc",
"random_op.cc",
"reduction_ops_all.cc",
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
index a23478af5b..d6e859f1aa 100644
--- a/tensorflow/core/kernels/fifo_queue.cc
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
return Status::OK();
}
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions and sessions. Running this op produces a single-element
+// tensor of handles to Queues on the corresponding device.
+FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context)
+ : TypedQueueOp(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
+}
+
+Status FIFOQueueOp::CreateResource(QueueInterface** ret) {
+ FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
+ component_shapes_, cinfo_.name());
+ return CreateTypedQueue(queue, ret);
+}
+
} // namespace tensorflow
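
The attrs parsed by FIFOQueueOp above map one-to-one onto the Python constructor arguments; a small sketch, assuming the TF 1.x API:

import tensorflow as tf

# capacity -> capacity_, dtypes -> component_types_, shapes -> component_shapes_
q = tf.FIFOQueue(capacity=10, dtypes=(tf.float32,), shapes=((3, 2),))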
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index f01d70924d..697ee81c39 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_
-#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
+#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
#include <deque>
#include <vector>
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -69,6 +70,22 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
};
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions and sessions. Running this op produces a single-element
+// tensor of handles to Queues on the corresponding device.
+class FIFOQueueOp : public TypedQueueOp {
+ public:
+ explicit FIFOQueueOp(OpKernelConstruction* context);
+
+ private:
+ Status CreateResource(QueueInterface** ret) override
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ std::vector<TensorShape> component_shapes_;
+ TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
+};
+
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#endif // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc
index b35bdbb2f0..80869768f1 100644
--- a/tensorflow/core/kernels/fifo_queue_op.cc
+++ b/tensorflow/core/kernels/fifo_queue_op.cc
@@ -13,50 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// See docs in ../ops/data_flow_ops.cc.
-
-#include <deque>
-#include <vector>
-
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fifo_queue.h"
-#include "tensorflow/core/kernels/queue_base.h"
-#include "tensorflow/core/kernels/queue_op.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-// Defines a FIFOQueueOp, which produces a Queue (specifically, one
-// backed by FIFOQueue) that persists across different graph
-// executions, and sessions. Running this op produces a single-element
-// tensor of handles to Queues in the corresponding device.
-class FIFOQueueOp : public TypedQueueOp {
- public:
- explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) {
- OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
- }
-
- private:
- Status CreateResource(QueueInterface** ret) override
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
- component_shapes_, cinfo_.name());
- return CreateTypedQueue(queue, ret);
- }
-
- std::vector<TensorShape> component_shapes_;
- TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp);
REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);
diff --git a/tensorflow/core/kernels/queue_op.cc b/tensorflow/core/kernels/queue_op.cc
new file mode 100644
index 0000000000..53f431ef3c
--- /dev/null
+++ b/tensorflow/core/kernels/queue_op.cc
@@ -0,0 +1,367 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/queue_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
+ if (capacity_ < 0) {
+ capacity_ = QueueBase::kUnbounded;
+ }
+ OP_REQUIRES_OK(context,
+ context->GetAttr("component_types", &component_types_));
+}
+
+void QueueOp::Compute(OpKernelContext* context) {
+ ResourceOpKernel<QueueInterface>::Compute(context);
+ mutex_lock l(mu_);
+ if (resource_ && context->track_allocations()) {
+ context->record_persistent_memory_allocation(resource_->MemoryUsed());
+ }
+}
+
+Status QueueOp::VerifyResource(QueueInterface* queue) {
+ return queue->MatchesNodeDef(def());
+}
+
+
+QueueOpKernel::QueueOpKernel(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) {
+ QueueInterface* queue;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
+ callback);
+ }
+ ComputeAsync(ctx, queue, [callback, queue]() {
+ queue->Unref();
+ callback();
+ });
+}
+
+QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context)
+ : QueueOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
+ // TODO(keveman): Enable timeout.
+ OP_REQUIRES(context, timeout_ == -1,
+ errors::InvalidArgument("Timeout not supported yet."));
+}
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input k: (k-1)th element of the tuple.
+EnqueueOp::EnqueueOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ DataTypeVector expected_inputs;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ expected_inputs.push_back(DT_RESOURCE);
+ } else {
+ expected_inputs.push_back(DT_STRING_REF);
+ }
+ for (DataType dt : queue->component_dtypes()) {
+ expected_inputs.push_back(dt);
+ }
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
+
+ QueueInterface::Tuple tuple;
+ OpInputList components;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+ callback);
+ for (const Tensor& Tcomponent : components) {
+ tuple.push_back(Tcomponent);
+ }
+
+ OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
+ queue->TryEnqueue(tuple, ctx, callback);
+}
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input k: (k-1)th element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ DataTypeVector expected_inputs;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ expected_inputs.push_back(DT_RESOURCE);
+ } else {
+ expected_inputs.push_back(DT_STRING_REF);
+ }
+ for (DataType dt : queue->component_dtypes()) {
+ expected_inputs.push_back(dt);
+ }
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
+
+ QueueInterface::Tuple tuple;
+ OpInputList components;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+ callback);
+ for (const Tensor& Tcomponent : components) {
+ tuple.push_back(Tcomponent);
+ }
+
+ OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
+ queue->TryEnqueueMany(tuple, ctx, callback);
+}
+
+EnqueueManyOp::~EnqueueManyOp() = default;
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+DequeueOp::DequeueOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueOp::~DequeueOp() = default;
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+DequeueManyOp::DequeueManyOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueManyOp requested ",
+ num_elements, " < 0 elements"),
+ callback);
+
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx,
+ ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeueMany(
+ num_elements, ctx, false /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueManyOp::~DequeueManyOp() = default;
+
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed. While the DequeueMany op will return an error if
+// there are fewer than num_elements elements left in the closed queue,
+// this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue) elements and will not
+// block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch. If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes. In this case,
+// the Queue does not block when closed.
+DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueUpToOp requested ",
+ num_elements, " < 0 elements"),
+ callback);
+
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx,
+ ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeueMany(
+ num_elements, ctx, true /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueUpToOp::~DequeueUpToOp() = default;
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+QueueCloseOp::QueueCloseOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
+ &cancel_pending_enqueues_));
+}
+
+void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ queue->Close(ctx, cancel_pending_enqueues_, callback);
+}
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+QueueSizeOp::QueueSizeOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {}
+
+void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ Tensor* Tqueue_size = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
+ Tqueue_size->flat<int32>().setConstant(queue->size());
+ callback();
+}
+
+QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {}
+
+void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ Tensor* Tqueue_is_closed = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
+ Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
+ callback();
+}
+
+} // namespace tensorflow
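
A short sketch of the DequeueMany / DequeueUpTo difference documented above (not part of this diff; assumes the TF 1.x API):

import tensorflow as tf

q = tf.FIFOQueue(10, tf.float32)
init = q.enqueue_many(([1.0, 2.0, 3.0],))
close = q.close()
up_to = q.dequeue_up_to(5)  # QueueDequeueUpToV2: allows a small batch when closed
many = q.dequeue_many(5)    # QueueDequeueManyV2: all-or-error

with tf.Session() as sess:
  sess.run([init, close])
  print(sess.run(up_to))    # [1. 2. 3.]: fewer than 5 elements, returns without blocking
  try:
    sess.run(many)          # queue is now empty and closed
  except tf.errors.OutOfRangeError:
    print("dequeue_many raises OutOfRangeError rather than returning a short batch")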
diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h
index 6c19f9841c..2efd838a5f 100644
--- a/tensorflow/core/kernels/queue_op.h
+++ b/tensorflow/core/kernels/queue_op.h
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_
-#define TENSORFLOW_KERNELS_QUEUE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
#include <deque>
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -32,22 +33,9 @@ namespace tensorflow {
// Defines a QueueOp, an abstract class for Queue construction ops.
class QueueOp : public ResourceOpKernel<QueueInterface> {
public:
- QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
- if (capacity_ < 0) {
- capacity_ = QueueBase::kUnbounded;
- }
- OP_REQUIRES_OK(context,
- context->GetAttr("component_types", &component_types_));
- }
+ QueueOp(OpKernelConstruction* context);
- void Compute(OpKernelContext* context) override {
- ResourceOpKernel<QueueInterface>::Compute(context);
- mutex_lock l(mu_);
- if (resource_ && context->track_allocations()) {
- context->record_persistent_memory_allocation(resource_->MemoryUsed());
- }
- }
+ void Compute(OpKernelContext* context) override;
protected:
// Variables accessible by subclasses
@@ -55,9 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> {
DataTypeVector component_types_;
private:
- Status VerifyResource(QueueInterface* queue) override {
- return queue->MatchesNodeDef(def());
- }
+ Status VerifyResource(QueueInterface* queue) override;
};
class TypedQueueOp : public QueueOp {
@@ -75,6 +61,211 @@ class TypedQueueOp : public QueueOp {
}
};
+// Queue manipulator kernels
+
+class QueueOpKernel : public AsyncOpKernel {
+ public:
+ explicit QueueOpKernel(OpKernelConstruction* context);
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final;
+
+ protected:
+ virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) = 0;
+};
+
+class QueueAccessOpKernel : public QueueOpKernel {
+ public:
+ explicit QueueAccessOpKernel(OpKernelConstruction* context);
+
+ protected:
+ int64 timeout_;
+};
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input k: (k-1)th element of the tuple.
+class EnqueueOp : public QueueAccessOpKernel {
+ public:
+ explicit EnqueueOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
+};
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+class EnqueueManyOp : public QueueAccessOpKernel {
+ public:
+ explicit EnqueueManyOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~EnqueueManyOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
+};
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+class DequeueOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
+};
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+class DequeueManyOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueManyOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueManyOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
+};
+
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed. While the DequeueMany op will return an error if
+// there are fewer than num_elements elements left in the closed queue,
+// this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue) elements and will not
+// block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch. If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes. In this case,
+// the Queue does not block when closed.
+class DequeueUpToOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueUpToOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueUpToOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
+};
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+class QueueCloseOp : public QueueOpKernel {
+ public:
+ explicit QueueCloseOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ bool cancel_pending_enqueues_;
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
+};
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+class QueueSizeOp : public QueueOpKernel {
+ public:
+ explicit QueueSizeOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
+};
+
+class QueueIsClosedOp : public QueueOpKernel {
+ public:
+ explicit QueueIsClosedOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
+};
+
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
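
To make the "1 + k inputs" layout documented for EnqueueOp concrete, a sketch assuming the TF 1.x API: the generated enqueue op consumes the queue handle followed by one tensor per component.

import tensorflow as tf

q = tf.FIFOQueue(10, (tf.int32, tf.float32))  # k = 2 components
enq = q.enqueue((5, 10.0))
print([t.dtype for t in enq.inputs])  # [tf.resource, tf.int32, tf.float32]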
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index 46a02854d7..c4d404259b 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -13,437 +13,44 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// See docs in ../ops/data_flow_ops.cc.
-
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-class QueueOpKernel : public AsyncOpKernel {
- public:
- explicit QueueOpKernel(OpKernelConstruction* context)
- : AsyncOpKernel(context) {}
-
- void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
- QueueInterface* queue;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
- callback);
- }
- ComputeAsync(ctx, queue, [callback, queue]() {
- queue->Unref();
- callback();
- });
- }
-
- protected:
- virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) = 0;
-};
-
-class QueueAccessOpKernel : public QueueOpKernel {
- public:
- explicit QueueAccessOpKernel(OpKernelConstruction* context)
- : QueueOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
- // TODO(keveman): Enable timeout.
- OP_REQUIRES(context, timeout_ == -1,
- errors::InvalidArgument("Timeout not supported yet."));
- }
-
- protected:
- int64 timeout_;
-};
-
-// Defines an EnqueueOp, the execution of which enqueues a tuple of
-// tensors in the given Queue.
-//
-// The op has 1 + k inputs, where k is the number of components in the
-// tuples stored in the given Queue:
-// - Input 0: queue handle.
-// - Input 1: 0th element of the tuple.
-// - ...
-// - Input (1+k): kth element of the tuple.
-class EnqueueOp : public QueueAccessOpKernel {
- public:
- explicit EnqueueOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- DataTypeVector expected_inputs;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- expected_inputs.push_back(DT_RESOURCE);
- } else {
- expected_inputs.push_back(DT_STRING_REF);
- }
- for (DataType dt : queue->component_dtypes()) {
- expected_inputs.push_back(dt);
- }
- OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
- callback);
-
- QueueInterface::Tuple tuple;
- OpInputList components;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
- callback);
- for (const Tensor& Tcomponent : components) {
- tuple.push_back(Tcomponent);
- }
-
- OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
- queue->TryEnqueue(tuple, ctx, callback);
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);
-// Defines an EnqueueManyOp, the execution of which slices each
-// component of a tuple of tensors along the 0th dimension, and
-// enqueues tuples of slices in the given Queue.
-//
-// The op has 1 + k inputs, where k is the number of components in the
-// tuples stored in the given Queue:
-// - Input 0: queue handle.
-// - Input 1: 0th element of the tuple.
-// - ...
-// - Input (1+k): kth element of the tuple.
-//
-// N.B. All tuple components must have the same size in the 0th
-// dimension.
-class EnqueueManyOp : public QueueAccessOpKernel {
- public:
- explicit EnqueueManyOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- DataTypeVector expected_inputs;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- expected_inputs.push_back(DT_RESOURCE);
- } else {
- expected_inputs.push_back(DT_STRING_REF);
- }
- for (DataType dt : queue->component_dtypes()) {
- expected_inputs.push_back(dt);
- }
- OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
- callback);
-
- QueueInterface::Tuple tuple;
- OpInputList components;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
- callback);
- for (const Tensor& Tcomponent : components) {
- tuple.push_back(Tcomponent);
- }
-
- OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
- queue->TryEnqueueMany(tuple, ctx, callback);
- }
-
- ~EnqueueManyOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
EnqueueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
EnqueueManyOp);
-// Defines a DequeueOp, the execution of which dequeues a tuple of
-// tensors from the given Queue.
-//
-// The op has one input, which is the handle of the appropriate
-// Queue. The op has k outputs, where k is the number of components in
-// the tuples stored in the given Queue, and output i is the ith
-// component of the dequeued tuple.
-class DequeueOp : public QueueAccessOpKernel {
- public:
- explicit DequeueOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components), callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);
-// Defines a DequeueManyOp, the execution of which concatenates the
-// requested number of elements from the given Queue along the 0th
-// dimension, and emits the result as a single tuple of tensors.
-//
-// The op has two inputs:
-// - Input 0: the handle to a queue.
-// - Input 1: the number of elements to dequeue.
-//
-// The op has k outputs, where k is the number of components in the
-// tuples stored in the given Queue, and output i is the ith component
-// of the dequeued tuple.
-class DequeueManyOp : public QueueAccessOpKernel {
- public:
- explicit DequeueManyOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- const Tensor& Tnum_elements = ctx->input(1);
- int32 num_elements = Tnum_elements.flat<int32>()(0);
-
- OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
- errors::InvalidArgument("DequeueManyOp requested ",
- num_elements, " < 0 elements"),
- callback);
-
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_RESOURCE, DT_INT32},
- queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_STRING_REF, DT_INT32},
- queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeueMany(
- num_elements, ctx, false /* allow_small_batch */,
- [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components),
- callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueManyOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
DequeueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
DequeueManyOp);
-// Defines a DequeueUpToOp, the execution of which concatenates the
-// requested number of elements from the given Queue along the 0th
-// dimension, and emits the result as a single tuple of tensors.
-//
-// The difference between this op and DequeueMany is the handling when
-// the Queue is closed. While the DequeueMany op will return if there
-// an error when there are less than num_elements elements left in the
-// closed queue, this op will return between 1 and
-// min(num_elements, elements_remaining_in_queue), and will not block.
-// If there are no elements left, then the standard DequeueMany error
-// is returned.
-//
-// This op only works if the underlying Queue implementation accepts
-// the allow_small_batch = true parameter to TryDequeueMany.
-// If it does not, an errors::Unimplemented exception is returned.
-//
-// The op has two inputs:
-// - Input 0: the handle to a queue.
-// - Input 1: the number of elements to dequeue.
-//
-// The op has k outputs, where k is the number of components in the
-// tuples stored in the given Queue, and output i is the ith component
-// of the dequeued tuple.
-//
-// The op has one attribute: allow_small_batch. If the Queue supports
-// it, setting this to true causes the queue to return smaller
-// (possibly zero length) batches when it is closed, up to however
-// many elements are available when the op executes. In this case,
-// the Queue does not block when closed.
-class DequeueUpToOp : public QueueAccessOpKernel {
- public:
- explicit DequeueUpToOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- const Tensor& Tnum_elements = ctx->input(1);
- int32 num_elements = Tnum_elements.flat<int32>()(0);
-
- OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
- errors::InvalidArgument("DequeueUpToOp requested ",
- num_elements, " < 0 elements"),
- callback);
-
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_RESOURCE, DT_INT32},
- queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_STRING_REF, DT_INT32},
- queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeueMany(
- num_elements, ctx, true /* allow_small_batch */,
- [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components),
- callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueUpToOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
DequeueUpToOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
DequeueUpToOp);
-// Defines a QueueCloseOp, which closes the given Queue. Closing a
-// Queue signals that no more elements will be enqueued in it.
-//
-// The op has one input, which is the handle of the appropriate Queue.
-class QueueCloseOp : public QueueOpKernel {
- public:
- explicit QueueCloseOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
- &cancel_pending_enqueues_));
- }
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- queue->Close(ctx, cancel_pending_enqueues_, callback);
- }
-
- private:
- bool cancel_pending_enqueues_;
- TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);
-// Defines a QueueSizeOp, which computes the number of elements in the
-// given Queue, and emits it as an output tensor.
-//
-// The op has one input, which is the handle of the appropriate Queue;
-// and one output, which is a single-element tensor containing the current
-// size of that Queue.
-class QueueSizeOp : public QueueOpKernel {
- public:
- explicit QueueSizeOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- Tensor* Tqueue_size = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
- Tqueue_size->flat<int32>().setConstant(queue->size());
- callback();
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);
-class QueueIsClosedOp : public QueueOpKernel {
- public:
- explicit QueueIsClosedOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- Tensor* Tqueue_is_closed = nullptr;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
- Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
- callback();
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU),
QueueIsClosedOp);
REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),