author    Eugene Brevdo <ebrevdo@gmail.com>    2016-01-19 09:39:34 -0800
committer Manjunath Kudlur <keveman@gmail.com>    2016-01-20 07:47:54 -0800
commit    e39629219e748b08177f2c457ba45d51f5370aae (patch)
tree      b0861c366bbbdf8330bb1862aa5a32251c012e65 /tensorflow/core/kernels
parent    f592f23775e2a6ac75496829db5005d3bb70a3d2 (diff)
PaddingFIFOQueue is like FIFOQueue but allows dynamic shapes (using padding with DequeueMany)
Change: 112482056
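The padding behavior the commit message describes can be sketched standalone: for a component whose shape has an unknown dimension, DequeueMany zero-pads every element up to the largest element in the batch. A minimal plain-C++ model of that semantic (illustrative names only, not the TF API):

#include <algorithm>
#include <cstdio>
#include <vector>

// Models DequeueMany for one int32 component of partial shape [-1]:
// each row is zero-padded to the length of the longest row in the batch.
std::vector<std::vector<int>> PadBatch(
    const std::vector<std::vector<int>>& rows) {
  size_t max_len = 0;
  for (const auto& r : rows) max_len = std::max(max_len, r.size());
  std::vector<std::vector<int>> out;
  for (const auto& r : rows) {
    std::vector<int> padded(max_len, 0);  // zero-fill first (cf. SetElementZero)
    std::copy(r.begin(), r.end(), padded.begin());
    out.push_back(std::move(padded));
  }
  return out;
}

int main() {
  // Two enqueued elements of shapes [3] and [2] dequeue as a [2, 3] batch.
  for (const auto& row : PadBatch({{1, 2, 3}, {4, 5}})) {
    for (int v : row) std::printf("%d ", v);
    std::printf("\n");  // prints "1 2 3" then "4 5 0"
  }
}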
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/fifo_queue.h                |   3
-rw-r--r--  tensorflow/core/kernels/fifo_queue_op.cc            |  59
-rw-r--r--  tensorflow/core/kernels/padding_fifo_queue.cc       | 370
-rw-r--r--  tensorflow/core/kernels/padding_fifo_queue.h        |  88
-rw-r--r--  tensorflow/core/kernels/padding_fifo_queue_op.cc    |  68
-rw-r--r--  tensorflow/core/kernels/queue_base.cc               |  19
-rw-r--r--  tensorflow/core/kernels/queue_base.h                |   2
-rw-r--r--  tensorflow/core/kernels/queue_op.h                  | 105
-rw-r--r--  tensorflow/core/kernels/random_shuffle_queue_op.cc  |  57
9 files changed, 659 insertions, 112 deletions
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index c8406e93bb..35ac38777e 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -52,7 +52,7 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
return queues_[0].size();
}
- private:
+ protected:
~FIFOQueue() override {}
// Helper for dequeuing a single element from queues_.
@@ -64,6 +64,7 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
OpKernelContext* ctx,
PersistentTensor* out_element);
+ private:
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
};
diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc
index 663250359c..a43c17637c 100644
--- a/tensorflow/core/kernels/fifo_queue_op.cc
+++ b/tensorflow/core/kernels/fifo_queue_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -38,70 +39,24 @@ namespace tensorflow {
// backed by FIFOQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
-class FIFOQueueOp : public OpKernel {
+class FIFOQueueOp : public QueueOp {
public:
- explicit FIFOQueueOp(OpKernelConstruction* context)
- : OpKernel(context), queue_handle_set_(false) {
- OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
- OP_REQUIRES_OK(context,
- context->allocate_persistent(DT_STRING, TensorShape({2}),
- &queue_handle_, nullptr));
- if (capacity_ < 0) {
- capacity_ = FIFOQueue::kUnbounded;
- }
- OP_REQUIRES_OK(context,
- context->GetAttr("component_types", &component_types_));
+ explicit FIFOQueueOp(OpKernelConstruction* context) : QueueOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
}
- ~FIFOQueueOp() override {
- // If the queue object was not shared, delete it.
- if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
- TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
- cinfo_.container(), cinfo_.name()));
- }
- }
-
- void Compute(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
- if (!queue_handle_set_) {
- OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
- }
- ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
- }
-
- private:
- Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
- QueueInterface* queue;
- auto creator = [this](QueueInterface** ret) {
+ protected:
+ CreatorCallback GetCreator() const override {
+ return [this](QueueInterface** ret) {
FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
component_shapes_, cinfo_.name());
*ret = queue;
return queue->Initialize();
};
- TF_RETURN_IF_ERROR(
- cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
- cinfo_.container(), cinfo_.name(), &queue, creator));
- core::ScopedUnref unref_me(queue);
- // Verify that the shared queue is compatible with the requested arguments.
- TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
- auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
- queue_handle_set_ = true;
- return Status::OK();
}
- int32 capacity_;
- DataTypeVector component_types_;
+ private:
std::vector<TensorShape> component_shapes_;
- ContainerInfo cinfo_;
-
- mutex mu_;
- PersistentTensor queue_handle_ GUARDED_BY(mu_);
- bool queue_handle_set_ GUARDED_BY(mu_);
-
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
};
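The refactor above is a template-method pattern: the shared handle bookkeeping moves into the QueueOp base class (added in queue_op.h below), and each concrete op supplies only a factory callback. A toy standalone model of the shape of that pattern (illustrative classes, not the TF ones):

#include <functional>
#include <iostream>
#include <memory>
#include <string>

struct Queue { std::string kind; };

// The base class owns the create-once logic; subclasses provide only a
// factory, mirroring QueueOp::GetCreator() in this change.
class QueueOpBase {
 public:
  using Creator = std::function<Queue*()>;
  virtual ~QueueOpBase() = default;
  void Compute() {
    if (!queue_) queue_.reset(GetCreator()());  // created lazily, exactly once
    std::cout << "backed by: " << queue_->kind << "\n";
  }
 protected:
  virtual Creator GetCreator() const = 0;
 private:
  std::unique_ptr<Queue> queue_;
};

class PaddingQueueOp : public QueueOpBase {
 protected:
  Creator GetCreator() const override {
    return [] { return new Queue{"padding_fifo"}; };
  }
};

int main() {
  PaddingQueueOp op;
  op.Compute();  // prints "backed by: padding_fifo"
  op.Compute();  // reuses the same queue
}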
diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc
new file mode 100644
index 0000000000..d2f8c06fdb
--- /dev/null
+++ b/tensorflow/core/kernels/padding_fifo_queue.cc
@@ -0,0 +1,370 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/data_flow_ops.cc.
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/padding_fifo_queue.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+PaddingFIFOQueue::PaddingFIFOQueue(
+ int capacity, const DataTypeVector& component_dtypes,
+ const std::vector<PartialTensorShape>& partial_shapes, const string& name)
+ : FIFOQueue(capacity, component_dtypes,
+ ConvertShapesPartialDimensionsToZero(partial_shapes), name),
+ partial_shapes_(partial_shapes) {}
+
+Status PaddingFIFOQueue::Initialize() {
+ Status s = FIFOQueue::Initialize();
+ if (!s.ok()) return s;
+
+ if (component_dtypes_.size() != partial_shapes_.size()) {
+ return errors::InvalidArgument(
+ "Shapes must be provided for all components, but received ",
+ component_dtypes_.size(), " dtypes and ", partial_shapes_.size(),
+ " shapes.");
+ }
+
+ return Status::OK();
+}
+
+/* static */
+Status PaddingFIFOQueue::GetElementComponent(
+ const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx,
+ PersistentTensor* out_tensor) {
+ TensorShape element_shape(tuple[component].shape());
+ Tensor* element_access = nullptr;
+ TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+ tuple[component].dtype(), element_shape, out_tensor, &element_access));
+ *element_access = tuple[component];
+ return Status::OK();
+}
+
+void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) {
+ if (num_elements == 0) {
+ Tuple tuple;
+ tuple.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ // TODO(josh11b,misard): Switch to allocate_output().
+ // See similar comment in fifo_queue.cc
+ Tensor element;
+ // Here, ManyOutShape returns zeros for undetermined shapes,
+ // which is exactly what we want to use.
+ ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0), &element);
+ tuple.emplace_back(element);
+ }
+ callback(tuple);
+ return;
+ }
+
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kDequeue, token); });
+ if (!already_cancelled) {
+ // TODO(josh11b): This makes two copies of callback, avoid this if possible.
+ dequeue_attempts_.emplace_back(
+ num_elements, [callback]() { callback(Tuple()); }, ctx, token,
+ [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int32 s = queues_[0].size();
+ if (closed_ && s < attempt->elements_requested) {
+ attempt->context->SetStatus(errors::OutOfRange(
+ "PaddingFIFOQueue '", name_, "' is closed and has ",
+ "insufficient elements (requested ",
+ attempt->elements_requested, ", current size ", s, ")"));
+
+ // TODO(mrry): Add support for producing a partial batch as
+ // output when the queue is closed.
+ if (!attempt->tuples.empty()) {
+ // Restore already-dequeued elements to the front of the queue.
+ for (int64 i = attempt->tuples.size() - 1; i >= 0; --i) {
+ for (int j = 0; j < num_components(); ++j) {
+ PersistentTensor element;
+ Status s = GetElementComponent(attempt->tuples[i], j,
+ attempt->context, &element);
+ if (!s.ok()) {
+ attempt->context->SetStatus(
+ errors::DataLoss("Failed to restore element from "
+ "partially-dequeued batch "
+ "to PaddingFIFOQueue: ",
+ s.error_message()));
+ }
+ queues_[j].push_front(element);
+ }
+ }
+ }
+ return kComplete;
+ }
+
+ RunResult result = kNoProgress;
+ for (; s > 0; --s) {
+ result = kProgress;
+ Tuple tuple;
+ DequeueLocked(attempt->context, &tuple);
+ attempt->tuples.push_back(tuple);
+ tuple.clear();
+ --attempt->elements_requested;
+
+ if (attempt->elements_requested == 0) {
+ // Finished. Allocate attempt->tuple and
+ // copy from attempt->tuples to attempt->tuple.
+ attempt->tuple.reserve(num_components());
+ const std::vector<Tuple>& tuples = attempt->tuples;
+
+ std::vector<bool> dynamic_shape;
+ const int64 batch_size = tuples.size();
+
+ for (int i = 0; i < num_components(); ++i) {
+ const PartialTensorShape partial_shape =
+ PartialTensorShape({batch_size})
+ .Concatenate(partial_shapes_[i]);
+ TensorShape shape({batch_size});
+
+ for (int j = 0; j < partial_shape.dims() - 1; ++j) {
+ if (partial_shape.dim_size(j + 1) > -1) {
+ shape.AddDim(partial_shape.dim_size(j + 1));
+ } else {
+ // Expand sizes to match.
+ int64 max_val = 0;
+ for (const Tuple& t : tuples) {
+ max_val = std::max(max_val, t[i].shape().dim_size(j));
+ }
+ shape.AddDim(max_val);
+ }
+ }
+
+ Tensor element;
+ attempt->context->allocate_temp(component_dtypes_[i], shape,
+ &element);
+
+ bool has_dynamic_shape = !partial_shape.IsFullyDefined();
+ if (has_dynamic_shape) {
+ // Set all values to zero because not all values
+ // will get written over.
+ attempt->context->SetStatus(SetElementZero(&element));
+ if (!attempt->context->status().ok()) return kComplete;
+ }
+
+ dynamic_shape.push_back(has_dynamic_shape);
+
+ // TODO(ebrevdo): should this be a persistent tensor?
+ attempt->tuple.emplace_back(element);
+ }
+
+ for (int index = 0; index < tuples.size(); ++index) {
+ for (int i = 0; i < num_components(); ++i) {
+ if (dynamic_shape[i]) {
+ // Slightly slower copy operation
+ attempt->context->SetStatus(CopyElementToLargerSlice(
+ tuples[index][i], &attempt->tuple[i], index));
+ } else {
+ attempt->context->SetStatus(CopyElementToSlice(
+ tuples[index][i], &attempt->tuple[i], index));
+ }
+ if (!attempt->context->status().ok()) return kComplete;
+ }
+ }
+ tuple = attempt->tuple;
+ attempt->tuples.clear();
+ attempt->done_callback = [callback, tuple]() {
+ callback(tuple);
+ };
+ return kComplete;
+ }
+ }
+ return result;
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
+ callback(Tuple());
+ }
+}
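The shape computation in the loop above reduces to: take known dimensions from the declared partial shape, and for each unknown (-1) dimension take the max of that dimension across the dequeued batch. A standalone sketch of just that computation (illustrative helper, not TF code):

#include <algorithm>
#include <cstdint>
#include <vector>

// partial_shape uses -1 for unknown dims; element_shapes[k][j] is the size
// of dim j of element k. Returns the padded batch shape [batch, d1, d2, ...].
std::vector<int64_t> PaddedBatchShape(
    const std::vector<int64_t>& partial_shape,
    const std::vector<std::vector<int64_t>>& element_shapes) {
  std::vector<int64_t> shape{static_cast<int64_t>(element_shapes.size())};
  for (size_t j = 0; j < partial_shape.size(); ++j) {
    if (partial_shape[j] >= 0) {
      shape.push_back(partial_shape[j]);  // known: use the declared size
    } else {
      int64_t max_val = 0;  // unknown: pad to the batch max
      for (const auto& es : element_shapes)
        max_val = std::max(max_val, es[j]);
      shape.push_back(max_val);
    }
  }
  return shape;
}

int main() {
  auto s = PaddedBatchShape({-1, 2}, {{3, 2}, {5, 2}});
  // s == {2, 5, 2}: batch of 2, unknown dim padded to 5, known dim kept at 2.
}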
+
+Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) {
+ return errors::InvalidArgument("Shape mismatch in tuple component ", i,
+ ". Expected ",
+ partial_shapes_[i].DebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ return Status::OK();
+}
+
+Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ const int64 batch_size = tuple[0].dim_size(0);
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ // Expected shape is [batch_size] + partial_shapes_[i]
+ const PartialTensorShape expected_shape =
+ PartialTensorShape({batch_size}).Concatenate(partial_shapes_[i]);
+ if (!expected_shape.IsCompatibleWith(tuple[i].shape())) {
+ return errors::InvalidArgument("Shape mismatch in tuple component ", i,
+ ". Expected ",
+ expected_shape.DebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ return Status::OK();
+}
+
+Status PaddingFIFOQueue::CompatibleNodeDefShapes(
+ const NodeDef& node_def) const {
+ std::vector<PartialTensorShape> requested_shapes;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
+ if (!PartialTensorShapeUtils::AreCompatible(requested_shapes,
+ partial_shapes_)) {
+ return errors::InvalidArgument(
+ "Shared queue '", name_, "' has component shapes ",
+ PartialTensorShapeUtils::PartialShapeListString(partial_shapes_),
+ " but requested component shapes were ",
+ PartialTensorShapeUtils::PartialShapeListString(requested_shapes));
+ } else {
+ return Status::OK();
+ }
+}
+
+Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
+ TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "PaddingFIFOQueue"));
+ TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
+ TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
+ TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def));
+ return Status::OK();
+}
+
+template <typename T, int NDIMS>
+Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
+ int index) {
+ DCHECK_NE(parent->dim_size(0), 0);
+ if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
+ TensorShape chip_shape = parent->shape();
+ chip_shape.RemoveDim(0);
+ return errors::Internal(
+ "HandleElementToLargerSlice Cannot copy slice: number of entries in "
+ "element is greater than number of elements in parent slice. ",
+ "Shapes are: [element]: ", element.shape().DebugString(),
+ ", [parent slice]: ", chip_shape.DebugString());
+ }
+ auto element_t = element.tensor<T, NDIMS>();
+ auto parent_t = parent->tensor<T, NDIMS + 1>();
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
+ slice_indices[0] = index;
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
+ slice_size[0] = 1;
+ for (int i = 1; i < slice_size.size(); ++i) {
+ slice_size[i] = element_t.dimension(i - 1);
+ }
+ parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
+ return Status::OK();
+}
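For rank-1 elements, the Eigen slice assignment above is equivalent to copying the element into the leading entries of one row of the batch tensor, leaving the zero-initialized tail as padding. A plain-loop sketch of that equivalence (illustrative, not the TF implementation):

#include <cstdio>
#include <vector>

// Copy `element` into row `index` of `parent`; columns past element.size()
// keep their prior (zero) values -- the "upper left hand corner" copy.
void ElementToLargerSlice1D(const std::vector<int>& element,
                            std::vector<std::vector<int>>* parent, int index) {
  for (size_t j = 0; j < element.size(); ++j) (*parent)[index][j] = element[j];
}

int main() {
  std::vector<std::vector<int>> parent(2, std::vector<int>(4, 0));
  ElementToLargerSlice1D({7, 8}, &parent, 1);
  std::printf("%d %d %d %d\n", parent[1][0], parent[1][1], parent[1][2],
              parent[1][3]);  // 7 8 0 0
}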
+
+namespace {
+
+template <int NDIMS>
+Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
+ int index) {
+#define HANDLE_TYPE(T) \
+ case DataTypeToEnum<T>::value: { \
+ return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
+ }
+
+ switch (element.dtype()) {
+ TF_CALL_ALL_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ default:
+ return errors::Unimplemented(
+ "HandleElementToLargerSliceWithRank Unhandled data type: ",
+ element.dtype());
+ }
+}
+
+} // namespace
+
+Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element,
+ Tensor* parent, int index) {
+ if (parent->dims() != element.dims() + 1) {
+ return errors::Internal(
+ "Mismatched ranks. Element's rank is: ", element.dims(),
+ " but element is meant to be a slice in output Tensor having rank: ",
+ parent->dims(), " (should be: ", element.dims() + 1, ")");
+ }
+
+#define HANDLE_DIMS(NDIMS) \
+ case NDIMS: { \
+ TF_RETURN_IF_ERROR( \
+ HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
+ return Status::OK(); \
+ }
+
+ switch (element.dims()) {
+ HANDLE_DIMS(0);
+ HANDLE_DIMS(1);
+ HANDLE_DIMS(2);
+ HANDLE_DIMS(3);
+ HANDLE_DIMS(4);
+#undef HANDLE_DIMS
+ default:
+ return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
+ element.dims());
+ }
+}
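The HANDLE_DIMS switch is the standard trick for bridging a runtime rank to the compile-time NDIMS template parameter Eigen requires; only the listed ranks get instantiated, and anything else falls through to an Unimplemented error. A minimal standalone form of the pattern:

#include <cstdio>

template <int NDIMS>
void ProcessWithRank() { std::printf("instantiated for rank %d\n", NDIMS); }

// Runtime rank -> compile-time template parameter, one case per supported rank.
void Process(int dims) {
  switch (dims) {
    case 0: ProcessWithRank<0>(); break;
    case 1: ProcessWithRank<1>(); break;
    case 2: ProcessWithRank<2>(); break;
    default: std::printf("unhandled rank: %d\n", dims);  // cf. Unimplemented
  }
}

int main() {
  Process(1);  // instantiated for rank 1
  Process(7);  // unhandled rank: 7
}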
+
+// Static method
+Status PaddingFIFOQueue::SetElementZero(Tensor* element) {
+#define HANDLE_TYPE(T) \
+ if (element->dtype() == DataTypeToEnum<T>::value) { \
+ element->flat<T>().setConstant(T()); \
+ return Status::OK(); \
+ }
+ TF_CALL_ALL_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ return errors::Unimplemented("SetElementZero Unhandled data type: ",
+ element->dtype());
+}
+
+std::vector<TensorShape> PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero(
+ const gtl::ArraySlice<PartialTensorShape>& partial_shapes) {
+ std::vector<TensorShape> shapes(partial_shapes.size());
+ for (int i = 0; i < shapes.size(); ++i) {
+ const PartialTensorShape& partial = partial_shapes[i];
+ TensorShape& shape = shapes[i];
+ for (int64 s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s);
+ }
+ return shapes;
+}
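ConvertShapesPartialDimensionsToZero exists so the FIFOQueue base constructor receives fully-defined TensorShapes; the zeros are placeholders only, since PaddingFIFOQueue overrides ValidateTuple/ValidateManyTuple to check elements against the stored partial shapes instead. The conversion itself is just (assuming every input shape has well-defined rank, per the header comment):

#include <cstdio>
#include <vector>

// -1 (unknown) becomes 0; known sizes pass through unchanged.
std::vector<long long> PartialToZero(const std::vector<long long>& partial) {
  std::vector<long long> shape;
  for (long long s : partial) shape.push_back(s < 0 ? 0 : s);
  return shape;
}

int main() {
  auto s = PartialToZero({-1, 2});
  std::printf("[%lld, %lld]\n", s[0], s[1]);  // [0, 2]
}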
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h
new file mode 100644
index 0000000000..afcbeea7e8
--- /dev/null
+++ b/tensorflow/core/kernels/padding_fifo_queue.h
@@ -0,0 +1,88 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
+#define TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/fifo_queue.h"
+#include "tensorflow/core/kernels/typed_queue.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/partial_tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class PaddingFIFOQueue : public FIFOQueue {
+ public:
+ PaddingFIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
+ const std::vector<PartialTensorShape>& component_shapes,
+ const string& name);
+
+ Status Initialize() override;
+
+ // Implementations of QueueInterface methods --------------------------------
+
+ void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) override;
+ Status MatchesNodeDef(const NodeDef& node_def) override;
+
+ protected:
+ Status ValidateManyTuple(const Tuple& tuple) override;
+ Status ValidateTuple(const Tuple& tuple) override;
+ Status CompatibleNodeDefShapes(const NodeDef& node_def) const;
+
+ // Convert a list of PartialTensorShape to a list of
+ // TensorShape.
+ // Any unknown dimension sizes are converted to 0.
+ // REQUIRED: All the input shapes have well defined rank.
+ static std::vector<TensorShape> ConvertShapesPartialDimensionsToZero(
+ const gtl::ArraySlice<PartialTensorShape>& partial_shapes);
+
+ // Sets the values in the given element to zero.
+ static Status SetElementZero(Tensor* element);
+
+ // Copies element into the index^th slice (in the first dimension)
+ // of parent. Allows for the parent's slice to have a larger size
+ // than the element, and copies the element into the upper left hand
+ // corner of the slice.
+ static Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
+ int index);
+
+ std::vector<PartialTensorShape> partial_shapes_;
+
+ private:
+ ~PaddingFIFOQueue() override {}
+
+ static Status GetElementComponent(const PaddingFIFOQueue::Tuple& tuple,
+ int component, OpKernelContext* ctx,
+ PersistentTensor* out_tensor);
+
+ static Status IsSameSizeExceptZerosInFirst(const TensorShape& first,
+ const TensorShape& second);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueue);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
diff --git a/tensorflow/core/kernels/padding_fifo_queue_op.cc b/tensorflow/core/kernels/padding_fifo_queue_op.cc
new file mode 100644
index 0000000000..08347bcf63
--- /dev/null
+++ b/tensorflow/core/kernels/padding_fifo_queue_op.cc
@@ -0,0 +1,68 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/data_flow_ops.cc.
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/padding_fifo_queue.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/kernels/queue_op.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/partial_tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+// Defines a PaddingFIFOQueueOp, which produces a Queue (specifically, one
+// backed by PaddingFIFOQueue) that persists across different graph
+// executions, and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+class PaddingFIFOQueueOp : public QueueOp {
+ public:
+ explicit PaddingFIFOQueueOp(OpKernelConstruction* context) : QueueOp(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
+ }
+
+ protected:
+ CreatorCallback GetCreator() const override {
+ return [this](QueueInterface** ret) {
+ PaddingFIFOQueue* queue = new PaddingFIFOQueue(
+ capacity_, component_types_, component_shapes_, cinfo_.name());
+ *ret = queue;
+ return queue->Initialize();
+ };
+ }
+
+ private:
+ std::vector<PartialTensorShape> component_shapes_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueueOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("PaddingFIFOQueue").Device(DEVICE_CPU),
+ PaddingFIFOQueueOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc
index 590e4e9123..b33af06408 100644
--- a/tensorflow/core/kernels/queue_base.cc
+++ b/tensorflow/core/kernels/queue_base.cc
@@ -32,8 +32,8 @@ Status HandleSliceToElement(const Tensor& parent, Tensor* element, int index) {
TensorShape chip_shape = parent.shape();
chip_shape.RemoveDim(0);
return errors::Internal(
- "Cannot copy slice: number of elements does not match. Shapes are: "
- "[element]: ",
+ "HandleSliceToElement Cannot copy slice: number of elements does not "
+ "match. Shapes are: [element]: ",
element->shape().DebugString(), ", [parent slice]: ",
chip_shape.DebugString());
}
@@ -50,8 +50,8 @@ Status HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
TensorShape chip_shape = parent->shape();
chip_shape.RemoveDim(0);
return errors::Internal(
- "Cannot copy slice: number of elements does not match. Shapes are: "
- "[element]: ",
+ "HandleElementToSlice Cannot copy slice: number of elements does not "
+ "match. Shapes are: [element]: ",
element.shape().DebugString(), ", [parent slice]: ",
chip_shape.DebugString());
}
@@ -156,7 +156,7 @@ Status QueueBase::ValidateTuple(const Tuple& tuple) {
TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
if (specified_shapes()) {
for (size_t i = 0; i < tuple.size(); ++i) {
- if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
+ if (!component_shapes_[i].IsSameSize(tuple[i].shape())) {
return errors::InvalidArgument(
"Shape mismatch in tuple component ", i, ". Expected ",
component_shapes_[i].ShortDebugString(), ", got ",
@@ -176,7 +176,7 @@ Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
for (size_t i = 0; i < tuple.size(); ++i) {
// Expected shape is [batch_size] + component_shapes_[i]
const TensorShape expected_shape = ManyOutShape(i, batch_size);
- if (!tuple[i].shape().IsSameSize(expected_shape)) {
+ if (!expected_shape.IsSameSize(tuple[i].shape())) {
return errors::InvalidArgument(
"Shape mismatch in tuple component ", i, ". Expected ",
expected_shape.ShortDebugString(), ", got ",
@@ -331,7 +331,6 @@ void QueueBase::FlushUnlocked() {
}
}
-// Static method
Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
int index) {
#define HANDLE_TYPE(DT) \
@@ -355,7 +354,8 @@ Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
HANDLE_TYPE(DT_QINT16);
HANDLE_TYPE(DT_QUINT16);
#undef HANDLE_TYPE
- return errors::Unimplemented("Unhandled data type: ", parent.dtype());
+ return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
+ parent.dtype());
}
// Static method
@@ -382,7 +382,8 @@ Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
HANDLE_TYPE(DT_QINT16);
HANDLE_TYPE(DT_QUINT16);
#undef HANDLE_TYPE
- return errors::Unimplemented("Unhandled data type: ", element.dtype());
+ return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
+ element.dtype());
}
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 10a6d65417..1c0e377cda 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -146,6 +146,8 @@ class QueueBase : public QueueInterface {
RunCallback run_callback; // must be run while holding mu_
bool is_cancelled;
Tuple tuple;
+ // tuples is used by some implementations allowing dynamic shapes.
+ std::vector<Tuple> tuples;
Attempt(int32 elements_requested, DoneCallback done_callback,
OpKernelContext* context, CancellationToken cancellation_token,
diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h
new file mode 100644
index 0000000000..a549439220
--- /dev/null
+++ b/tensorflow/core/kernels/queue_op.h
@@ -0,0 +1,105 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_
+#define TENSORFLOW_KERNELS_QUEUE_OP_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+// Defines a QueueOp, an abstract class for Queue construction ops.
+class QueueOp : public OpKernel {
+ public:
+ QueueOp(OpKernelConstruction* context)
+ : OpKernel(context), queue_handle_set_(false) {
+ OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_STRING, TensorShape({2}),
+ &queue_handle_, nullptr));
+ if (capacity_ < 0) {
+ capacity_ = QueueBase::kUnbounded;
+ }
+ OP_REQUIRES_OK(context,
+ context->GetAttr("component_types", &component_types_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ if (!queue_handle_set_) {
+ OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
+ }
+ ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
+ }
+
+ protected:
+ ~QueueOp() override {
+ // If the queue object was not shared, delete it.
+ if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
+ TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
+ cinfo_.container(), cinfo_.name()));
+ }
+ }
+
+ protected:
+ typedef std::function<Status(QueueInterface**)> CreatorCallback;
+
+ // Subclasses must override this
+ virtual CreatorCallback GetCreator() const = 0;
+
+ // Variables accessible by subclasses
+ int32 capacity_;
+ DataTypeVector component_types_;
+ ContainerInfo cinfo_;
+
+ private:
+ Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
+ CreatorCallback creator = GetCreator();
+ QueueInterface* queue;
+ TF_RETURN_IF_ERROR(
+ cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
+ cinfo_.container(), cinfo_.name(), &queue, creator));
+ core::ScopedUnref unref_me(queue);
+ // Verify that the shared queue is compatible with the requested arguments.
+ TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
+ auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ queue_handle_set_ = true;
+ return Status::OK();
+ }
+
+ mutex mu_;
+ PersistentTensor queue_handle_ GUARDED_BY(mu_);
+ bool queue_handle_set_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
index da391224af..e9b6ead381 100644
--- a/tensorflow/core/kernels/random_shuffle_queue_op.cc
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
@@ -404,17 +405,10 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
// backed by RandomShuffleQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
-class RandomShuffleQueueOp : public OpKernel {
+class RandomShuffleQueueOp : public QueueOp {
public:
explicit RandomShuffleQueueOp(OpKernelConstruction* context)
- : OpKernel(context), queue_handle_set_(false) {
- OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
- OP_REQUIRES_OK(context,
- context->allocate_persistent(DT_STRING, TensorShape({2}),
- &queue_handle_, nullptr));
- if (capacity_ < 0) {
- capacity_ = RandomShuffleQueue::kUnbounded;
- }
+ : QueueOp(context) {
OP_REQUIRES_OK(context,
context->GetAttr("min_after_dequeue", &min_after_dequeue_));
OP_REQUIRES(context, min_after_dequeue_ >= 0,
@@ -427,32 +421,12 @@ class RandomShuffleQueueOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
- OP_REQUIRES_OK(context,
- context->GetAttr("component_types", &component_types_));
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
}
- ~RandomShuffleQueueOp() override {
- // If the queue object was not shared, delete it.
- if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
- TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
- cinfo_.container(), cinfo_.name()));
- }
- }
-
- void Compute(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
- if (!queue_handle_set_) {
- OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
- }
- ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
- }
-
- private:
- Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
- QueueInterface* queue;
- auto creator = [this](QueueInterface** ret) {
+ protected:
+ CreatorCallback GetCreator() const override {
+ return [this](QueueInterface** ret) {
auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_,
seed2_, component_types_,
component_shapes_, cinfo_.name());
@@ -464,30 +438,13 @@ class RandomShuffleQueueOp : public OpKernel {
}
return s;
};
- TF_RETURN_IF_ERROR(
- cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
- cinfo_.container(), cinfo_.name(), &queue, creator));
- core::ScopedUnref unref_me(queue);
- // Verify that the shared queue is compatible with the requested arguments.
- TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
- auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
- queue_handle_set_ = true;
- return Status::OK();
}
- int32 capacity_;
+ private:
int32 min_after_dequeue_;
int64 seed_;
int64 seed2_;
- DataTypeVector component_types_;
std::vector<TensorShape> component_shapes_;
- ContainerInfo cinfo_;
-
- mutex mu_;
- PersistentTensor queue_handle_ GUARDED_BY(mu_);
- bool queue_handle_set_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp);
};